Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Start moving some builtin type formats to the dialect #80421

Open
wants to merge 1 commit into
base: users/zero9178/qualified-trait
Choose a base branch
from

Conversation

zero9178
Copy link
Member

@zero9178 zero9178 commented Feb 2, 2024

Most types and attributes in the builtin dialect are parsed and printed using special-purpose printers and parsers for that type. They also use the low-level Printer rather than the AsmPrinter, making the implementations inconsistent compared to all other dialects in MLIR.

This PR starts moving some builtin types to be parsed using the usual print and parse methods like all other MLIR dialects. This has the following advantages:

  • The implementation now looks like any other dialect's
  • It is now possible to use assemblyFormat for builtin types and attributes
  • The code can be easily moved to other dialects if desired
  • Arguably better layering and less code
  • As a side-effect, it is now also possible to write !builtin.<type> for any types if desired

A future benefit would include being able to print types and attributes in stripped format as well (e.g. <f32> vs complex<f32>), just like all other dialect types and attributes. This is currently explicitly disabled as it causes a LOT of changes in IR syntax and I believe some ambiguities in the parser.

For the purpose of reviewing and incremental development, this PR only moves tuple, tensor, none, memref and complex. The plan is to eventually move all attributes and types where the current syntax can be implemented within the dialect.

For backwards compatibility with the existing syntax, the builtin dialect is special-cased in the printer where the builtin. prefix is omitted.

Depends on #80420

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 2, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-ods

Author: Markus Böck (zero9178)

Changes

Most types and attributes in the builtin dialect are parsed and printed using special-purpose printers and parsers for that type. They also use the low-level Printer rather than the AsmPrinter, making the implementations inconsistent compared to all other dialects in MLIR.

This PR starts moving some builtin types to be parsed using the usual print and parse methods like all other MLIR dialects. This has the following advantages:

  • The implementation now looks like any other dialect's
  • It is now possible to use assemblyFormat for builtin types and attributes
  • The code can be easily moved to other dialects if desired
  • Arguably better layering and less code
  • As a side-effect, it is now also possible to write !builtin.&lt;type&gt; for any types if desired

A future benefit would include being able to print types and attributes in stripped format as well (e.g. &lt;f32&gt; vs complex&lt;f32&gt;), just like all other dialect types and attributes. This is currently explicitly disabled as it causes a LOT of changes in IR syntax and I believe some ambiguities in the parser.

For the purpose of reviewing and incremental development, this PR only moves tuple, tensor, none, memref and complex. The plan is to eventually move all attributes and types where the current syntax can be implemented within the dialect.

For backwards compatibility with the existing syntax, the builtin dialect is special-cased in the printer where the builtin. prefix is omitted.

Depends on #80420


Patch is 28.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80421.diff

11 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinDialect.td (+1-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+23-2)
  • (modified) mlir/include/mlir/IR/OpImplementation.h (+5)
  • (modified) mlir/lib/AsmParser/DialectSymbolParser.cpp (+24)
  • (modified) mlir/lib/AsmParser/Parser.h (+11-13)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+3-208)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+14-58)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+150)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+5-5)
  • (modified) mlir/test/IR/invalid.mlir (+2-2)
  • (added) mlir/test/IR/qualified-builtin.mlir (+11)
diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td
index c131107634b44..a8627170288c9 100644
--- a/mlir/include/mlir/IR/BuiltinDialect.td
+++ b/mlir/include/mlir/IR/BuiltinDialect.td
@@ -22,7 +22,7 @@ def Builtin_Dialect : Dialect {
   let name = "builtin";
   let cppNamespace = "::mlir";
   let useDefaultAttributePrinterParser = 0;
-  let useDefaultTypePrinterParser = 0;
+  let useDefaultTypePrinterParser = 1;
   let extraClassDeclaration = [{
   private:
     // Register the builtin Attributes.
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..f3a51d2155040 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -25,7 +25,8 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
 // Base class for Builtin dialect types.
 class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
                    string baseCppClass = "::mlir::Type">
-    : TypeDef<Builtin_Dialect, name, traits, baseCppClass> {
+    : TypeDef<Builtin_Dialect, name, !listconcat(traits, [PrintTypeQualified]),
+        baseCppClass> {
   let mnemonic = ?;
   let typeName = "builtin." # typeMnemonic;
 }
@@ -62,6 +63,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
   ];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
+
+  let mnemonic = "complex";
+  let assemblyFormat = "`<` $elementType `>`";
 }
 
 //===----------------------------------------------------------------------===//
@@ -668,6 +672,9 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
+
+  let mnemonic = "memref";
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -698,6 +705,8 @@ def Builtin_None : Builtin_Type<"None", "none"> {
   let extraClassDeclaration = [{
     static NoneType get(MLIRContext *context);
   }];
+
+  let mnemonic = "none";
 }
 
 //===----------------------------------------------------------------------===//
@@ -849,6 +858,9 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
+
+  let mnemonic = "tensor";
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -884,7 +896,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
     tuple<i32, f32, tensor<i1>, i5>
     ```
   }];
-  let parameters = (ins "ArrayRef<Type>":$types);
+  let parameters = (ins OptionalArrayRefParameter<"Type">:$types);
   let builders = [
     TypeBuilder<(ins "TypeRange":$elementTypes), [{
       return $_get($_ctxt, elementTypes);
@@ -916,6 +928,9 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
       return getTypes()[index];
     }
   }];
+
+  let mnemonic = "tuple";
+  let assemblyFormat = "`<` (`>`) : ($types^ `>`)?";
 }
 
 //===----------------------------------------------------------------------===//
@@ -994,6 +1009,9 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
+
+  let mnemonic = "memref";
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1043,6 +1061,9 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
+
+  let mnemonic = "tensor";
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 402399cf29665..de6245eea862a 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -186,6 +186,11 @@ class AsmPrinter {
   /// provide a valid type for the attribute.
   virtual void printAttributeWithoutType(Attribute attr);
 
+  /// Print the given attribute without its type if and only if the type is the
+  /// default type for the given attribute.
+  /// E.g. '1 : i64' is printed as just '1'.
+  virtual void printAttributeWithoutDefaultType(Attribute attr);
+
   /// Print the alias for the given attribute, return failure if no alias could
   /// be printed.
   virtual LogicalResult printAlias(Attribute attr);
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 80cce7e6ae43f..400d26398afc3 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -322,6 +322,30 @@ Type Parser::parseExtendedType() {
       });
 }
 
+Type Parser::parseExtendedBuiltinType() {
+  // Initially set to just the mnemonic of the type.
+  llvm::StringRef symbolData = getToken().getSpelling();
+  const char *startOfTypePos = symbolData.data();
+  consumeToken();
+  // Extend 'symbolData' to include the body if it is not a singleton type.
+  // Note that all types in the builtin type always use the pretty dialect form
+  // aka 'dialect.mnemonic<body>'.
+  if (getToken().is(Token::less))
+    if (failed(parseDialectSymbolBody(symbolData)))
+      return nullptr;
+
+  const char *endOfTypePos = getToken().getLoc().getPointer();
+
+  // With the body of the type captured, hand it off to the dialect parser.
+  resetToken(startOfTypePos);
+  CustomDialectAsmParser customParser(symbolData, *this);
+  Type type = builtinDialect->parseType(customParser);
+
+  // Move the lexer past the type.
+  resetToken(endOfTypePos);
+  return type;
+}
+
 //===----------------------------------------------------------------------===//
 // mlir::parseAttribute/parseType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index b959e67b8e258..73080c88ff6b0 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -11,6 +11,7 @@
 
 #include "ParserState.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/OpImplementation.h"
 #include <optional>
 
@@ -28,9 +29,14 @@ class Parser {
   using Delimiter = OpAsmParser::Delimiter;
 
   Builder builder;
+  /// Cached instance of the builtin dialect for parsing builtins.
+  Dialect *builtinDialect;
 
   Parser(ParserState &state)
-      : builder(state.config.getContext()), state(state) {}
+      : builder(state.config.getContext()),
+        builtinDialect(
+            builder.getContext()->getLoadedDialect<BuiltinDialect>()),
+        state(state) {}
 
   // Helper methods to get stuff from the parser-global state.
   ParserState &getState() const { return state; }
@@ -192,27 +198,19 @@ class Parser {
   /// Parse an arbitrary type.
   Type parseType();
 
-  /// Parse a complex type.
-  Type parseComplexType();
-
   /// Parse an extended type.
   Type parseExtendedType();
 
+  /// Parse an extended type from the builtin dialect where the '!builtin'
+  /// prefix is missing.
+  Type parseExtendedBuiltinType();
+
   /// Parse a function type.
   Type parseFunctionType();
 
-  /// Parse a memref type.
-  Type parseMemRefType();
-
   /// Parse a non function type.
   Type parseNonFunctionType();
 
-  /// Parse a tensor type.
-  Type parseTensorType();
-
-  /// Parse a tuple type.
-  Type parseTupleType();
-
   /// Parse a vector type.
   VectorType parseVectorType();
   ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 5da931b77b3be..95df69b899b8a 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -11,12 +11,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "Parser.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/TensorEncoding.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -123,29 +120,6 @@ ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
   return success();
 }
 
-/// Parse a complex type.
-///
-///   complex-type ::= `complex` `<` type `>`
-///
-Type Parser::parseComplexType() {
-  consumeToken(Token::kw_complex);
-
-  // Parse the '<'.
-  if (parseToken(Token::less, "expected '<' in complex type"))
-    return nullptr;
-
-  SMLoc elementTypeLoc = getToken().getLoc();
-  auto elementType = parseType();
-  if (!elementType ||
-      parseToken(Token::greater, "expected '>' in complex type"))
-    return nullptr;
-  if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
-    return emitError(elementTypeLoc, "invalid element type for complex"),
-           nullptr;
-
-  return ComplexType::get(elementType);
-}
-
 /// Parse a function type.
 ///
 ///   function-type ::= type-list-parens `->` function-result-type
@@ -162,95 +136,6 @@ Type Parser::parseFunctionType() {
   return builder.getFunctionType(arguments, results);
 }
 
-/// Parse a memref type.
-///
-///   memref-type ::= ranked-memref-type | unranked-memref-type
-///
-///   ranked-memref-type ::= `memref` `<` dimension-list-ranked type
-///                          (`,` layout-specification)? (`,` memory-space)? `>`
-///
-///   unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
-///
-///   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
-///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-///   layout-specification ::= semi-affine-map | strided-layout | attribute
-///   memory-space ::= integer-literal | attribute
-///
-Type Parser::parseMemRefType() {
-  SMLoc loc = getToken().getLoc();
-  consumeToken(Token::kw_memref);
-
-  if (parseToken(Token::less, "expected '<' in memref type"))
-    return nullptr;
-
-  bool isUnranked;
-  SmallVector<int64_t, 4> dimensions;
-
-  if (consumeIf(Token::star)) {
-    // This is an unranked memref type.
-    isUnranked = true;
-    if (parseXInDimensionList())
-      return nullptr;
-
-  } else {
-    isUnranked = false;
-    if (parseDimensionListRanked(dimensions))
-      return nullptr;
-  }
-
-  // Parse the element type.
-  auto typeLoc = getToken().getLoc();
-  auto elementType = parseType();
-  if (!elementType)
-    return nullptr;
-
-  // Check that memref is formed from allowed types.
-  if (!BaseMemRefType::isValidElementType(elementType))
-    return emitError(typeLoc, "invalid memref element type"), nullptr;
-
-  MemRefLayoutAttrInterface layout;
-  Attribute memorySpace;
-
-  auto parseElt = [&]() -> ParseResult {
-    // Either it is MemRefLayoutAttrInterface or memory space attribute.
-    Attribute attr = parseAttribute();
-    if (!attr)
-      return failure();
-
-    if (isa<MemRefLayoutAttrInterface>(attr)) {
-      layout = cast<MemRefLayoutAttrInterface>(attr);
-    } else if (memorySpace) {
-      return emitError("multiple memory spaces specified in memref type");
-    } else {
-      memorySpace = attr;
-      return success();
-    }
-
-    if (isUnranked)
-      return emitError("cannot have affine map for unranked memref type");
-    if (memorySpace)
-      return emitError("expected memory space to be last in memref type");
-
-    return success();
-  };
-
-  // Parse a list of mappings and address space if present.
-  if (!consumeIf(Token::greater)) {
-    // Parse comma separated list of affine maps, followed by memory space.
-    if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
-        parseCommaSeparatedListUntil(Token::greater, parseElt,
-                                     /*allowEmptyList=*/false)) {
-      return nullptr;
-    }
-  }
-
-  if (isUnranked)
-    return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
-
-  return getChecked<MemRefType>(loc, dimensions, elementType, layout,
-                                memorySpace);
-}
-
 /// Parse any type except the function type.
 ///
 ///   non-function-type ::= integer-type
@@ -272,14 +157,12 @@ Type Parser::parseNonFunctionType() {
   switch (getToken().getKind()) {
   default:
     return (emitWrongTokenError("expected non-function type"), nullptr);
-  case Token::kw_memref:
-    return parseMemRefType();
   case Token::kw_tensor:
-    return parseTensorType();
+  case Token::kw_memref:
   case Token::kw_complex:
-    return parseComplexType();
   case Token::kw_tuple:
-    return parseTupleType();
+  case Token::kw_none:
+    return parseExtendedBuiltinType();
   case Token::kw_vector:
     return parseVectorType();
   // integer-type
@@ -344,11 +227,6 @@ Type Parser::parseNonFunctionType() {
     consumeToken(Token::kw_index);
     return builder.getIndexType();
 
-  // none-type
-  case Token::kw_none:
-    consumeToken(Token::kw_none);
-    return builder.getNoneType();
-
   // extended type
   case Token::exclamation_identifier:
     return parseExtendedType();
@@ -361,89 +239,6 @@ Type Parser::parseNonFunctionType() {
   }
 }
 
-/// Parse a tensor type.
-///
-///   tensor-type ::= `tensor` `<` dimension-list type `>`
-///   dimension-list ::= dimension-list-ranked | `*x`
-///
-Type Parser::parseTensorType() {
-  consumeToken(Token::kw_tensor);
-
-  if (parseToken(Token::less, "expected '<' in tensor type"))
-    return nullptr;
-
-  bool isUnranked;
-  SmallVector<int64_t, 4> dimensions;
-
-  if (consumeIf(Token::star)) {
-    // This is an unranked tensor type.
-    isUnranked = true;
-
-    if (parseXInDimensionList())
-      return nullptr;
-
-  } else {
-    isUnranked = false;
-    if (parseDimensionListRanked(dimensions))
-      return nullptr;
-  }
-
-  // Parse the element type.
-  auto elementTypeLoc = getToken().getLoc();
-  auto elementType = parseType();
-
-  // Parse an optional encoding attribute.
-  Attribute encoding;
-  if (consumeIf(Token::comma)) {
-    auto parseResult = parseOptionalAttribute(encoding);
-    if (parseResult.has_value()) {
-      if (failed(parseResult.value()))
-        return nullptr;
-      if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
-        if (failed(v.verifyEncoding(dimensions, elementType,
-                                    [&] { return emitError(); })))
-          return nullptr;
-      }
-    }
-  }
-
-  if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
-    return nullptr;
-  if (!TensorType::isValidElementType(elementType))
-    return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
-
-  if (isUnranked) {
-    if (encoding)
-      return emitError("cannot apply encoding to unranked tensor"), nullptr;
-    return UnrankedTensorType::get(elementType);
-  }
-  return RankedTensorType::get(dimensions, elementType, encoding);
-}
-
-/// Parse a tuple type.
-///
-///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
-///
-Type Parser::parseTupleType() {
-  consumeToken(Token::kw_tuple);
-
-  // Parse the '<'.
-  if (parseToken(Token::less, "expected '<' in tuple type"))
-    return nullptr;
-
-  // Check for an empty tuple by directly parsing '>'.
-  if (consumeIf(Token::greater))
-    return TupleType::get(getContext());
-
-  // Parse the element types and the '>'.
-  SmallVector<Type, 4> types;
-  if (parseTypeListNoParens(types) ||
-      parseToken(Token::greater, "expected '>' in tuple type"))
-    return nullptr;
-
-  return TupleType::get(getContext(), types);
-}
-
 /// Parse a vector type.
 ///
 /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b8b7473bf0f8..0679d4135048a 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2132,6 +2132,13 @@ static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
 /// Print the given dialect symbol to the stream.
 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
                                StringRef dialectName, StringRef symString) {
+  // Treat the builtin dialect special by eliding the '<symPrefix>builtin'
+  // prefix.
+  if (dialectName == "builtin") {
+    os << symString;
+    return;
+  }
+
   os << symPrefix << dialectName;
 
   // If this symbol name is simple enough, print it directly in pretty form,
@@ -2599,64 +2606,6 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         printType(vectorTy.getElementType());
         os << '>';
       })
-      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
-        os << "tensor<";
-        printDimensionList(tensorTy.getShape());
-        if (!tensorTy.getShape().empty())
-          os << 'x';
-        printType(tensorTy.getElementType());
-        // Only print the encoding attribute value if set.
-        if (tensorTy.getEncoding()) {
-          os << ", ";
-          printAttribute(tensorTy.getEncoding());
-        }
-        os << '>';
-      })
-      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
-        os << "tensor<*x";
-        printType(tensorTy.getElementType());
-        os << '>';
-      })
-      .Case<MemRefType>([&](MemRefType memrefTy) {
-        os << "memref<";
-        printDimensionList(memrefTy.getShape());
-        if (!memrefTy.getShape().empty())
-          os << 'x';
-        printType(memrefTy.getElementType());
-        MemRefLayoutAttrInterface layout = memrefTy.getLayout();
-        if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
-          os << ", ";
-          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
-        }
-        // Only print the memory space if it is the non-default one.
-        if (memrefTy.getMemorySpace()) {
-          os << ", ";
-          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
-        }
-        os << '>';
-      })
-      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
-        os << "memref<*x";
-        printType(memrefTy.getElementType());
-        // Only print the memory space if it is the non-default one.
-        if (memrefTy.getMemorySpace()) {
-          os << ", ";
-          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
-        }
-        os << '>';
-      })
-      .Case<ComplexType>([&](ComplexType complexTy) {
-        os << "complex<";
-        printType(complexTy.getElementType());
-        os << '>';
-      })
-      .Case<TupleType>([&](TupleType tupleTy) {
-        os << "tuple<";
-        interleaveComma(tupleTy.getTypes(),
-                        [&](Type type) { printType(type); });
-        os << '>';
-      })
-      .Case<NoneType>([&](Type) { os << "none"; })
       .Default([&](Type type) { return printDialectType(type); });
 }
 
@@ -2799,6 +2748,13 @@ void AsmPrinter::printAttributeWithoutType(Attribute attr) {
   impl->printAttribute(attr, Impl::AttrTypeElision::Must);
 }
 
+void AsmPrinter::printAttributeWithoutDefaultType(Attribute attr) {
+  assert(
+      impl &&
+      "expected AsmPrinter::printAttributeWithoutDefaultType to be overriden");
+  impl->printAttribute(attr, Impl::AttrTypeElision::May);
+}
+
 void AsmPrinter::printKeywordOrString(StringRef keyword) {
   assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
   ::printKeywordOrString(keyword, impl->getStream());
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b8ee3d452803..e160c0ff4c33d 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -10,10 +10,13 @@
 #include "TypeDetail.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/APFloat.h"
@@ -25,6 +28,52 @@
 using namespace mlir;
 using namespace mlir::detail;
 
+//===----------------------------------------------------------------------===//
+// Custom printing and parsing
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseMemRe...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 2, 2024

@llvm/pr-subscribers-mlir

Author: Markus Böck (zero9178)

Changes

Most types and attributes in the builtin dialect are parsed and printed using special-purpose printers and parsers for that type. They also use the low-level Printer rather than the AsmPrinter, making the implementations inconsistent compared to all other dialects in MLIR.

This PR starts moving some builtin types to be parsed using the usual print and parse methods like all other MLIR dialects. This has the following advantages:

  • The implementation now looks like any other dialect's
  • It is now possible to use assemblyFormat for builtin types and attributes
  • The code can be easily moved to other dialects if desired
  • Arguably better layering and less code
  • As a side-effect, it is now also possible to write !builtin.&lt;type&gt; for any types if desired

A future benefit would include being able to print types and attributes in stripped format as well (e.g. &lt;f32&gt; vs complex&lt;f32&gt;), just like all other dialect types and attributes. This is currently explicitly disabled as it causes a LOT of changes in IR syntax and I believe some ambiguities in the parser.

For the purpose of reviewing and incremental development, this PR only moves tuple, tensor, none, memref and complex. The plan is to eventually move all attributes and types where the current syntax can be implemented within the dialect.

For backwards compatibility with the existing syntax, the builtin dialect is special-cased in the printer where the builtin. prefix is omitted.

Depends on #80420


Patch is 28.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80421.diff

11 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinDialect.td (+1-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+23-2)
  • (modified) mlir/include/mlir/IR/OpImplementation.h (+5)
  • (modified) mlir/lib/AsmParser/DialectSymbolParser.cpp (+24)
  • (modified) mlir/lib/AsmParser/Parser.h (+11-13)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+3-208)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+14-58)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+150)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+5-5)
  • (modified) mlir/test/IR/invalid.mlir (+2-2)
  • (added) mlir/test/IR/qualified-builtin.mlir (+11)
diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td
index c131107634b44..a8627170288c9 100644
--- a/mlir/include/mlir/IR/BuiltinDialect.td
+++ b/mlir/include/mlir/IR/BuiltinDialect.td
@@ -22,7 +22,7 @@ def Builtin_Dialect : Dialect {
   let name = "builtin";
   let cppNamespace = "::mlir";
   let useDefaultAttributePrinterParser = 0;
-  let useDefaultTypePrinterParser = 0;
+  let useDefaultTypePrinterParser = 1;
   let extraClassDeclaration = [{
   private:
     // Register the builtin Attributes.
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..f3a51d2155040 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -25,7 +25,8 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
 // Base class for Builtin dialect types.
 class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
                    string baseCppClass = "::mlir::Type">
-    : TypeDef<Builtin_Dialect, name, traits, baseCppClass> {
+    : TypeDef<Builtin_Dialect, name, !listconcat(traits, [PrintTypeQualified]),
+        baseCppClass> {
   let mnemonic = ?;
   let typeName = "builtin." # typeMnemonic;
 }
@@ -62,6 +63,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
   ];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
+
+  let mnemonic = "complex";
+  let assemblyFormat = "`<` $elementType `>`";
 }
 
 //===----------------------------------------------------------------------===//
@@ -668,6 +672,9 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
+
+  let mnemonic = "memref";
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -698,6 +705,8 @@ def Builtin_None : Builtin_Type<"None", "none"> {
   let extraClassDeclaration = [{
     static NoneType get(MLIRContext *context);
   }];
+
+  let mnemonic = "none";
 }
 
 //===----------------------------------------------------------------------===//
@@ -849,6 +858,9 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
+
+  let mnemonic = "tensor";
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -884,7 +896,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
     tuple<i32, f32, tensor<i1>, i5>
     ```
   }];
-  let parameters = (ins "ArrayRef<Type>":$types);
+  let parameters = (ins OptionalArrayRefParameter<"Type">:$types);
   let builders = [
     TypeBuilder<(ins "TypeRange":$elementTypes), [{
       return $_get($_ctxt, elementTypes);
@@ -916,6 +928,9 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
       return getTypes()[index];
     }
   }];
+
+  let mnemonic = "tuple";
+  let assemblyFormat = "`<` (`>`) : ($types^ `>`)?";
 }
 
 //===----------------------------------------------------------------------===//
@@ -994,6 +1009,9 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
+
+  let mnemonic = "memref";
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1043,6 +1061,9 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
+
+  let mnemonic = "tensor";
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 402399cf29665..de6245eea862a 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -186,6 +186,11 @@ class AsmPrinter {
   /// provide a valid type for the attribute.
   virtual void printAttributeWithoutType(Attribute attr);
 
+  /// Print the given attribute without its type if and only if the type is the
+  /// default type for the given attribute.
+  /// E.g. '1 : i64' is printed as just '1'.
+  virtual void printAttributeWithoutDefaultType(Attribute attr);
+
   /// Print the alias for the given attribute, return failure if no alias could
   /// be printed.
   virtual LogicalResult printAlias(Attribute attr);
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 80cce7e6ae43f..400d26398afc3 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -322,6 +322,30 @@ Type Parser::parseExtendedType() {
       });
 }
 
+Type Parser::parseExtendedBuiltinType() {
+  // Initially set to just the mnemonic of the type.
+  llvm::StringRef symbolData = getToken().getSpelling();
+  const char *startOfTypePos = symbolData.data();
+  consumeToken();
+  // Extend 'symbolData' to include the body if it is not a singleton type.
+  // Note that all types in the builtin type always use the pretty dialect form
+  // aka 'dialect.mnemonic<body>'.
+  if (getToken().is(Token::less))
+    if (failed(parseDialectSymbolBody(symbolData)))
+      return nullptr;
+
+  const char *endOfTypePos = getToken().getLoc().getPointer();
+
+  // With the body of the type captured, hand it off to the dialect parser.
+  resetToken(startOfTypePos);
+  CustomDialectAsmParser customParser(symbolData, *this);
+  Type type = builtinDialect->parseType(customParser);
+
+  // Move the lexer past the type.
+  resetToken(endOfTypePos);
+  return type;
+}
+
 //===----------------------------------------------------------------------===//
 // mlir::parseAttribute/parseType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index b959e67b8e258..73080c88ff6b0 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -11,6 +11,7 @@
 
 #include "ParserState.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/OpImplementation.h"
 #include <optional>
 
@@ -28,9 +29,14 @@ class Parser {
   using Delimiter = OpAsmParser::Delimiter;
 
   Builder builder;
+  /// Cached instance of the builtin dialect for parsing builtins.
+  Dialect *builtinDialect;
 
   Parser(ParserState &state)
-      : builder(state.config.getContext()), state(state) {}
+      : builder(state.config.getContext()),
+        builtinDialect(
+            builder.getContext()->getLoadedDialect<BuiltinDialect>()),
+        state(state) {}
 
   // Helper methods to get stuff from the parser-global state.
   ParserState &getState() const { return state; }
@@ -192,27 +198,19 @@ class Parser {
   /// Parse an arbitrary type.
   Type parseType();
 
-  /// Parse a complex type.
-  Type parseComplexType();
-
   /// Parse an extended type.
   Type parseExtendedType();
 
+  /// Parse an extended type from the builtin dialect where the '!builtin'
+  /// prefix is missing.
+  Type parseExtendedBuiltinType();
+
   /// Parse a function type.
   Type parseFunctionType();
 
-  /// Parse a memref type.
-  Type parseMemRefType();
-
   /// Parse a non function type.
   Type parseNonFunctionType();
 
-  /// Parse a tensor type.
-  Type parseTensorType();
-
-  /// Parse a tuple type.
-  Type parseTupleType();
-
   /// Parse a vector type.
   VectorType parseVectorType();
   ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 5da931b77b3be..95df69b899b8a 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -11,12 +11,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "Parser.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/TensorEncoding.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -123,29 +120,6 @@ ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
   return success();
 }
 
-/// Parse a complex type.
-///
-///   complex-type ::= `complex` `<` type `>`
-///
-Type Parser::parseComplexType() {
-  consumeToken(Token::kw_complex);
-
-  // Parse the '<'.
-  if (parseToken(Token::less, "expected '<' in complex type"))
-    return nullptr;
-
-  SMLoc elementTypeLoc = getToken().getLoc();
-  auto elementType = parseType();
-  if (!elementType ||
-      parseToken(Token::greater, "expected '>' in complex type"))
-    return nullptr;
-  if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
-    return emitError(elementTypeLoc, "invalid element type for complex"),
-           nullptr;
-
-  return ComplexType::get(elementType);
-}
-
 /// Parse a function type.
 ///
 ///   function-type ::= type-list-parens `->` function-result-type
@@ -162,95 +136,6 @@ Type Parser::parseFunctionType() {
   return builder.getFunctionType(arguments, results);
 }
 
-/// Parse a memref type.
-///
-///   memref-type ::= ranked-memref-type | unranked-memref-type
-///
-///   ranked-memref-type ::= `memref` `<` dimension-list-ranked type
-///                          (`,` layout-specification)? (`,` memory-space)? `>`
-///
-///   unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
-///
-///   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
-///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-///   layout-specification ::= semi-affine-map | strided-layout | attribute
-///   memory-space ::= integer-literal | attribute
-///
-Type Parser::parseMemRefType() {
-  SMLoc loc = getToken().getLoc();
-  consumeToken(Token::kw_memref);
-
-  if (parseToken(Token::less, "expected '<' in memref type"))
-    return nullptr;
-
-  bool isUnranked;
-  SmallVector<int64_t, 4> dimensions;
-
-  if (consumeIf(Token::star)) {
-    // This is an unranked memref type.
-    isUnranked = true;
-    if (parseXInDimensionList())
-      return nullptr;
-
-  } else {
-    isUnranked = false;
-    if (parseDimensionListRanked(dimensions))
-      return nullptr;
-  }
-
-  // Parse the element type.
-  auto typeLoc = getToken().getLoc();
-  auto elementType = parseType();
-  if (!elementType)
-    return nullptr;
-
-  // Check that memref is formed from allowed types.
-  if (!BaseMemRefType::isValidElementType(elementType))
-    return emitError(typeLoc, "invalid memref element type"), nullptr;
-
-  MemRefLayoutAttrInterface layout;
-  Attribute memorySpace;
-
-  auto parseElt = [&]() -> ParseResult {
-    // Either it is MemRefLayoutAttrInterface or memory space attribute.
-    Attribute attr = parseAttribute();
-    if (!attr)
-      return failure();
-
-    if (isa<MemRefLayoutAttrInterface>(attr)) {
-      layout = cast<MemRefLayoutAttrInterface>(attr);
-    } else if (memorySpace) {
-      return emitError("multiple memory spaces specified in memref type");
-    } else {
-      memorySpace = attr;
-      return success();
-    }
-
-    if (isUnranked)
-      return emitError("cannot have affine map for unranked memref type");
-    if (memorySpace)
-      return emitError("expected memory space to be last in memref type");
-
-    return success();
-  };
-
-  // Parse a list of mappings and address space if present.
-  if (!consumeIf(Token::greater)) {
-    // Parse comma separated list of affine maps, followed by memory space.
-    if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
-        parseCommaSeparatedListUntil(Token::greater, parseElt,
-                                     /*allowEmptyList=*/false)) {
-      return nullptr;
-    }
-  }
-
-  if (isUnranked)
-    return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
-
-  return getChecked<MemRefType>(loc, dimensions, elementType, layout,
-                                memorySpace);
-}
-
 /// Parse any type except the function type.
 ///
 ///   non-function-type ::= integer-type
@@ -272,14 +157,12 @@ Type Parser::parseNonFunctionType() {
   switch (getToken().getKind()) {
   default:
     return (emitWrongTokenError("expected non-function type"), nullptr);
-  case Token::kw_memref:
-    return parseMemRefType();
   case Token::kw_tensor:
-    return parseTensorType();
+  case Token::kw_memref:
   case Token::kw_complex:
-    return parseComplexType();
   case Token::kw_tuple:
-    return parseTupleType();
+  case Token::kw_none:
+    return parseExtendedBuiltinType();
   case Token::kw_vector:
     return parseVectorType();
   // integer-type
@@ -344,11 +227,6 @@ Type Parser::parseNonFunctionType() {
     consumeToken(Token::kw_index);
     return builder.getIndexType();
 
-  // none-type
-  case Token::kw_none:
-    consumeToken(Token::kw_none);
-    return builder.getNoneType();
-
   // extended type
   case Token::exclamation_identifier:
     return parseExtendedType();
@@ -361,89 +239,6 @@ Type Parser::parseNonFunctionType() {
   }
 }
 
-/// Parse a tensor type.
-///
-///   tensor-type ::= `tensor` `<` dimension-list type `>`
-///   dimension-list ::= dimension-list-ranked | `*x`
-///
-Type Parser::parseTensorType() {
-  consumeToken(Token::kw_tensor);
-
-  if (parseToken(Token::less, "expected '<' in tensor type"))
-    return nullptr;
-
-  bool isUnranked;
-  SmallVector<int64_t, 4> dimensions;
-
-  if (consumeIf(Token::star)) {
-    // This is an unranked tensor type.
-    isUnranked = true;
-
-    if (parseXInDimensionList())
-      return nullptr;
-
-  } else {
-    isUnranked = false;
-    if (parseDimensionListRanked(dimensions))
-      return nullptr;
-  }
-
-  // Parse the element type.
-  auto elementTypeLoc = getToken().getLoc();
-  auto elementType = parseType();
-
-  // Parse an optional encoding attribute.
-  Attribute encoding;
-  if (consumeIf(Token::comma)) {
-    auto parseResult = parseOptionalAttribute(encoding);
-    if (parseResult.has_value()) {
-      if (failed(parseResult.value()))
-        return nullptr;
-      if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
-        if (failed(v.verifyEncoding(dimensions, elementType,
-                                    [&] { return emitError(); })))
-          return nullptr;
-      }
-    }
-  }
-
-  if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
-    return nullptr;
-  if (!TensorType::isValidElementType(elementType))
-    return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
-
-  if (isUnranked) {
-    if (encoding)
-      return emitError("cannot apply encoding to unranked tensor"), nullptr;
-    return UnrankedTensorType::get(elementType);
-  }
-  return RankedTensorType::get(dimensions, elementType, encoding);
-}
-
-/// Parse a tuple type.
-///
-///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
-///
-Type Parser::parseTupleType() {
-  consumeToken(Token::kw_tuple);
-
-  // Parse the '<'.
-  if (parseToken(Token::less, "expected '<' in tuple type"))
-    return nullptr;
-
-  // Check for an empty tuple by directly parsing '>'.
-  if (consumeIf(Token::greater))
-    return TupleType::get(getContext());
-
-  // Parse the element types and the '>'.
-  SmallVector<Type, 4> types;
-  if (parseTypeListNoParens(types) ||
-      parseToken(Token::greater, "expected '>' in tuple type"))
-    return nullptr;
-
-  return TupleType::get(getContext(), types);
-}
-
 /// Parse a vector type.
 ///
 /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b8b7473bf0f8..0679d4135048a 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2132,6 +2132,13 @@ static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
 /// Print the given dialect symbol to the stream.
 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
                                StringRef dialectName, StringRef symString) {
+  // Treat the builtin dialect special by eliding the '<symPrefix>builtin'
+  // prefix.
+  if (dialectName == "builtin") {
+    os << symString;
+    return;
+  }
+
   os << symPrefix << dialectName;
 
   // If this symbol name is simple enough, print it directly in pretty form,
@@ -2599,64 +2606,6 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         printType(vectorTy.getElementType());
         os << '>';
       })
-      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
-        os << "tensor<";
-        printDimensionList(tensorTy.getShape());
-        if (!tensorTy.getShape().empty())
-          os << 'x';
-        printType(tensorTy.getElementType());
-        // Only print the encoding attribute value if set.
-        if (tensorTy.getEncoding()) {
-          os << ", ";
-          printAttribute(tensorTy.getEncoding());
-        }
-        os << '>';
-      })
-      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
-        os << "tensor<*x";
-        printType(tensorTy.getElementType());
-        os << '>';
-      })
-      .Case<MemRefType>([&](MemRefType memrefTy) {
-        os << "memref<";
-        printDimensionList(memrefTy.getShape());
-        if (!memrefTy.getShape().empty())
-          os << 'x';
-        printType(memrefTy.getElementType());
-        MemRefLayoutAttrInterface layout = memrefTy.getLayout();
-        if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
-          os << ", ";
-          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
-        }
-        // Only print the memory space if it is the non-default one.
-        if (memrefTy.getMemorySpace()) {
-          os << ", ";
-          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
-        }
-        os << '>';
-      })
-      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
-        os << "memref<*x";
-        printType(memrefTy.getElementType());
-        // Only print the memory space if it is the non-default one.
-        if (memrefTy.getMemorySpace()) {
-          os << ", ";
-          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
-        }
-        os << '>';
-      })
-      .Case<ComplexType>([&](ComplexType complexTy) {
-        os << "complex<";
-        printType(complexTy.getElementType());
-        os << '>';
-      })
-      .Case<TupleType>([&](TupleType tupleTy) {
-        os << "tuple<";
-        interleaveComma(tupleTy.getTypes(),
-                        [&](Type type) { printType(type); });
-        os << '>';
-      })
-      .Case<NoneType>([&](Type) { os << "none"; })
       .Default([&](Type type) { return printDialectType(type); });
 }
 
@@ -2799,6 +2748,13 @@ void AsmPrinter::printAttributeWithoutType(Attribute attr) {
   impl->printAttribute(attr, Impl::AttrTypeElision::Must);
 }
 
+void AsmPrinter::printAttributeWithoutDefaultType(Attribute attr) {
+  assert(
+      impl &&
+      "expected AsmPrinter::printAttributeWithoutDefaultType to be overriden");
+  impl->printAttribute(attr, Impl::AttrTypeElision::May);
+}
+
 void AsmPrinter::printKeywordOrString(StringRef keyword) {
   assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
   ::printKeywordOrString(keyword, impl->getStream());
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b8ee3d452803..e160c0ff4c33d 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -10,10 +10,13 @@
 #include "TypeDetail.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TensorEncoding.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/APFloat.h"
@@ -25,6 +28,52 @@
 using namespace mlir;
 using namespace mlir::detail;
 
+//===----------------------------------------------------------------------===//
+// Custom printing and parsing
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseMemRe...
[truncated]

@zero9178 zero9178 force-pushed the users/zero9178/simplify-builtin-parsing branch from 5007356 to e67e980 Compare February 2, 2024 12:14
Most types and attributes in the builtin dialect are parsed and printed using special-purpose printers and parsers for that type. They also use the low-level `Printer` rather than the `AsmPrinter`, making the implementations inconsistent compared to all other dialects in MLIR.

This PR starts moving some builtin types to be parsed using the usual `print` and `parse` methods like all other MLIR dialects. This has the following advantages:
* The implementation now looks like any other dialect's types
* It is now possible to use `assemblyFormat` for builtin types and attributes
* The code can be easily moved to other dialects if desired
* Arguably better layering and less code
* As a side-effect, it is now also possible to write `!builtin.<type>` for any types moved

A future benefit would include being able to print types and attributes in stripped format as well (e.g. `<f32>` vs `complex<f32>`), just like all other dialect types and attributes. This is currently explicitly disabled as it causes a LOT of changes in IR syntax and I believe some ambiguities in the parser.

For the purpose of reviewing and incremental development, this PR only moves `tuple`, `tensor`, `none`, `memref` and `complex`. The plan is to eventually move all attributes and types where the current syntax can be implemented within the dialect.

For backwards compatibility with the existing syntax, the builtin dialect is special-cased in the printer where the `builtin.` prefix is omitted.
@@ -25,7 +25,8 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
// Base class for Builtin dialect types.
class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: TypeDef<Builtin_Dialect, name, traits, baseCppClass> {
: TypeDef<Builtin_Dialect, name, !listconcat(traits, [PrintTypeQualified]),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite follow why PrintTypeQualified here? Do you have a test which shows that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The best example as to why this is needed is probably complex. Not using PrintTypeQualified means changing the syntax of ALL operations in the complex dialect as they all suddenly use the unqualified syntax.
Additionally, parseCustomTypeWithFallback currently does not handle the builtin types correctly if written verbose (e.g. complex<f32>) since the logic currently assumes that verbose types always lead with the dialect name (!builtin).
I tried a bit what logic changes would be required and did not find an obvious answer. It therefore also breaks existing input parsing.
So to summarize:

  • To stay compatible with existing syntax and tests and avoid creating massive churn
  • To avoid incompatiblity with the current implementation of stripped parsing and printing

If you like I can take a closer look at fixing parseCustomTypeWithFallback here, but thought I'd rather make that a future PR to keep this one backwards compatible.

@zero9178
Copy link
Member Author

ping

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:ods mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants