[MLIR] print/parse resource handle key quoted and escaped #119746

Open
wants to merge 7 commits into
base: main

sorenlassen
Contributor

Resource keys that contain special or non-printable characters cannot be parsed from MLIR assembly. Nothing prevents you from specifying such a key when you create, e.g., a DenseResourceElementsAttr, and the key works fine in other ways, including bytecode emission and parsing.

This PR fixes the parsing problem by quoting and escaping keys that contain special or non-printable characters in MLIR assembly, in the same way as symbols, e.g.:

module attributes {
  fst = dense_resource<resource_fst> : tensor<2xf16>,
  snd = dense_resource<"resource\09snd"> : tensor<2xf16>
} {}

{-#
  dialect_resources: {
    builtin: {
      resource_fst: "0x0200000001000200",
      "resource\09snd": "0x0200000008000900"
    }
  }
#-}

Keys without special or non-printable characters are still printed unquoted, so the change is effectively backwards compatible.
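The quoting rule described above can be sketched in standalone C++. This is a simplified illustration, not the MLIR implementation: `isBareKey` and `formatKey` are hypothetical names, and the bare-key rule (a letter or `_` followed by letters, digits, `_`, `$` or `.`) and the `\XX` uppercase hex escape are assumptions chosen to reproduce the quoted keys shown in this PR.

```cpp
#include <cctype>
#include <string>

// Hypothetical helper: true if `key` can be printed without quotes.
// Assumed rule: starts with a letter or '_', then letters, digits,
// '_', '$' or '.'.
bool isBareKey(const std::string &key) {
  if (key.empty())
    return false;
  auto uc = [](char c) { return static_cast<unsigned char>(c); };
  if (!std::isalpha(uc(key[0])) && key[0] != '_')
    return false;
  for (char c : key.substr(1))
    if (!std::isalnum(uc(c)) && c != '_' && c != '$' && c != '.')
      return false;
  return true;
}

// Hypothetical helper: returns a bare key unchanged; otherwise wraps it in
// double quotes, doubling backslashes and writing quotes and non-printable
// characters as a backslash followed by two uppercase hex digits.
std::string formatKey(const std::string &key) {
  if (isBareKey(key))
    return key;
  static const char hex[] = "0123456789ABCDEF";
  std::string out = "\"";
  for (char c : key) {
    unsigned char u = static_cast<unsigned char>(c);
    if (c == '\\') {
      out += "\\\\";
    } else if (std::isprint(u) && c != '"') {
      out += c;
    } else {
      out += '\\';
      out += hex[u >> 4];
      out += hex[u & 0x0F];
    }
  }
  out += '"';
  return out;
}
```

Under these assumptions, `formatKey("resource_fst")` returns the key unchanged, while a key containing a tab between `resource` and `snd` comes back as `"resource\09snd"`, matching the example above.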

The change is tested by:

  1. adding a test for a dense resource handle key with special characters to dense-resource-elements-attr.mlir
  2. adding special and non-printable characters to some resource keys in the existing lit tests pretty-resources-print.mlir and mlir/test/Bytecode/resources.mlir
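These tests depend on the parser turning a quoted key back into the original key. A minimal sketch of that decoding step follows, assuming the escape forms `\\`, `\"`, `\n`, `\t` and `\XX` (two hex digits). `unescapeKey` is a hypothetical name used only for illustration; the PR itself obtains the decoded text from the string token rather than from a helper like this.

```cpp
#include <cstddef>
#include <optional>
#include <string>

// Hypothetical helper: decodes the body of a quoted key (the text between
// the double quotes). Returns std::nullopt on a malformed escape.
std::optional<std::string> unescapeKey(const std::string &body) {
  auto hexValue = [](char c) -> int {
    if (c >= '0' && c <= '9') return c - '0';
    if (c >= 'a' && c <= 'f') return c - 'a' + 10;
    if (c >= 'A' && c <= 'F') return c - 'A' + 10;
    return -1;
  };
  std::string out;
  for (std::size_t i = 0; i < body.size(); ++i) {
    char c = body[i];
    if (c != '\\') {
      out += c;
      continue;
    }
    if (++i == body.size())
      return std::nullopt; // dangling backslash
    char e = body[i];
    if (e == '\\' || e == '"') { out += e; continue; }
    if (e == 'n') { out += '\n'; continue; }
    if (e == 't') { out += '\t'; continue; }
    // Otherwise expect two hex digits, e.g. \09 for a tab.
    if (i + 1 >= body.size())
      return std::nullopt;
    int hi = hexValue(e), lo = hexValue(body[i + 1]);
    if (hi < 0 || lo < 0)
      return std::nullopt;
    out += static_cast<char>(hi * 16 + lo);
    ++i;
  }
  return out;
}
```

The decoded key is a new string that does not appear verbatim in the input buffer, so it has to be stored in an owning string rather than referenced as a view into the source text.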


@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Dec 12, 2024
@llvmbot
Member

llvmbot commented Dec 12, 2024

@llvm/pr-subscribers-mlir-core

Author: Soren Lassen (sorenlassen)

Full diff: https://github.com/llvm/llvm-project/pull/119746.diff

8 Files Affected:

  • (modified) mlir/include/mlir/IR/OpImplementation.h (+2-1)
  • (modified) mlir/lib/AsmParser/AsmParserImpl.h (+1-1)
  • (modified) mlir/lib/AsmParser/Parser.cpp (+26-9)
  • (modified) mlir/lib/AsmParser/Parser.h (+4-1)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+31-31)
  • (modified) mlir/test/Bytecode/resources.mlir (+4-4)
  • (modified) mlir/test/IR/dense-resource-elements-attr.mlir (+15)
  • (modified) mlir/test/IR/pretty-resources-print.mlir (+3-3)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 6c1ff4d0e5e6b9..b4506d58386ec8 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -202,7 +202,8 @@ class AsmPrinter {
   /// special or non-printable characters in it.
   virtual void printSymbolName(StringRef symbolRef);
 
-  /// Print a handle to the given dialect resource.
+  /// Print a handle to the given dialect resource. The handle key is quoted and
+  /// escaped if it has any special or non-printable characters in it.
   virtual void printResourceHandle(const AsmDialectResourceHandle &resource);
 
   /// Print an optional arrow followed by a type list.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index d5b72d63813a4e..9ef7592b19605f 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -514,7 +514,7 @@ class AsmParserImpl : public BaseT {
       return parser.emitError() << "dialect '" << dialect->getNamespace()
                                 << "' does not expect resource handles";
     }
-    StringRef resourceName;
+    std::string resourceName;
     return parser.parseResourceHandle(interface, resourceName);
   }
 
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index e3db248164672c..2a03659142a5ee 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -412,15 +412,32 @@ ParseResult Parser::parseOptionalKeyword(StringRef *keyword) {
   return success();
 }
 
+ParseResult Parser::parseOptionalKeywordOrString(std::string *result) {
+  StringRef keyword;
+  if (succeeded(parseOptionalKeyword(&keyword))) {
+    *result = keyword.str();
+    return success();
+  }
+
+  // Parse a quoted string token if present.
+  if (!getToken().is(Token::string))
+    return failure();
+
+  if (result)
+    *result = getToken().getStringValue();
+  consumeToken();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Resource Parsing
 
 FailureOr<AsmDialectResourceHandle>
 Parser::parseResourceHandle(const OpAsmDialectInterface *dialect,
-                            StringRef &name) {
+                            std::string &name) {
   assert(dialect && "expected valid dialect interface");
   SMLoc nameLoc = getToken().getLoc();
-  if (failed(parseOptionalKeyword(&name)))
+  if (failed(parseOptionalKeywordOrString(&name)))
     return emitError("expected identifier key for 'resource' entry");
   auto &resources = getState().symbols.dialectResources;
 
@@ -451,7 +468,7 @@ Parser::parseResourceHandle(Dialect *dialect) {
     return emitError() << "dialect '" << dialect->getNamespace()
                        << "' does not expect resource handles";
   }
-  StringRef resourceName;
+  std::string resourceName;
   return parseResourceHandle(interface, resourceName);
 }
 
@@ -2530,8 +2547,8 @@ class TopLevelOperationParser : public Parser {
 /// textual format.
 class ParsedResourceEntry : public AsmParsedResourceEntry {
 public:
-  ParsedResourceEntry(StringRef key, SMLoc keyLoc, Token value, Parser &p)
-      : key(key), keyLoc(keyLoc), value(value), p(p) {}
+  ParsedResourceEntry(std::string key, SMLoc keyLoc, Token value, Parser &p)
+      : key(std::move(key)), keyLoc(keyLoc), value(value), p(p) {}
   ~ParsedResourceEntry() override = default;
 
   StringRef getKey() const final { return key; }
@@ -2607,7 +2624,7 @@ class ParsedResourceEntry : public AsmParsedResourceEntry {
   }
 
 private:
-  StringRef key;
+  std::string key;
   SMLoc keyLoc;
   Token value;
   Parser &p;
@@ -2736,7 +2753,7 @@ ParseResult TopLevelOperationParser::parseDialectResourceFileMetadata() {
     return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult {
       // Parse the name of the resource entry.
       SMLoc keyLoc = getToken().getLoc();
-      StringRef key;
+      std::string key;
       if (failed(parseResourceHandle(handler, key)) ||
           parseToken(Token::colon, "expected ':'"))
         return failure();
@@ -2763,8 +2780,8 @@ ParseResult TopLevelOperationParser::parseExternalResourceFileMetadata() {
     return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult {
       // Parse the name of the resource entry.
       SMLoc keyLoc = getToken().getLoc();
-      StringRef key;
-      if (failed(parseOptionalKeyword(&key)))
+      std::string key;
+      if (failed(parseOptionalKeywordOrString(&key)))
         return emitError(
             "expected identifier key for 'external_resources' entry");
       if (parseToken(Token::colon, "expected ':'"))
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 37670bd789fecb..86d572b882ee0e 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -171,13 +171,16 @@ class Parser {
   /// Parse a keyword, if present, into 'keyword'.
   ParseResult parseOptionalKeyword(StringRef *keyword);
 
+  /// Parse an optional keyword or string, if present, into 'result'.
+  ParseResult parseOptionalKeywordOrString(std::string *result);
+
   //===--------------------------------------------------------------------===//
   // Resource Parsing
   //===--------------------------------------------------------------------===//
 
   /// Parse a handle to a dialect resource within the assembly format.
   FailureOr<AsmDialectResourceHandle>
-  parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name);
+  parseResourceHandle(const OpAsmDialectInterface *dialect, std::string &name);
   FailureOr<AsmDialectResourceHandle> parseResourceHandle(Dialect *dialect);
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61b90bc9b0a7bb..7c2c7f0875e2ad 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2146,13 +2146,6 @@ void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
   os << ')';
 }
 
-void AsmPrinter::Impl::printResourceHandle(
-    const AsmDialectResourceHandle &resource) {
-  auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
-  os << interface->getResourceKey(resource);
-  state.getDialectResources()[resource.getDialect()].insert(resource);
-}
-
 /// Returns true if the given dialect symbol data is simple enough to print in
 /// the pretty form. This is essentially when the symbol takes the form:
 ///   identifier (`<` body `>`)?
@@ -2237,6 +2230,13 @@ static void printElidedElementsAttr(raw_ostream &os) {
   os << R"(dense_resource<__elided__>)";
 }
 
+void AsmPrinter::Impl::printResourceHandle(
+    const AsmDialectResourceHandle &resource) {
+  auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
+  ::printKeywordOrString(interface->getResourceKey(resource), os);
+  state.getDialectResources()[resource.getDialect()].insert(resource);
+}
+
 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
   return state.getAliasState().getAlias(attr, os);
 }
@@ -3331,41 +3331,41 @@ void OperationPrinter::printResourceFileMetadata(
     auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
       checkAddMetadataDict();
 
-      auto printFormatting = [&]() {
-        // Emit the top-level resource entry if we haven't yet.
-        if (!std::exchange(hadResource, true)) {
-          if (needResourceComma)
-            os << "," << newLine;
-          os << "  " << dictName << "_resources: {" << newLine;
-        }
-        // Emit the parent resource entry if we haven't yet.
-        if (!std::exchange(hadEntry, true)) {
-          if (needEntryComma)
-            os << "," << newLine;
-          os << "    " << name << ": {" << newLine;
-        } else {
-          os << "," << newLine;
-        }
-      };
-
+      std::string resourceStr;
+      auto printResourceStr = [&](raw_ostream &os) { os << resourceStr; };
       std::optional<uint64_t> charLimit =
           printerFlags.getLargeResourceStringLimit();
       if (charLimit.has_value()) {
-        std::string resourceStr;
         llvm::raw_string_ostream ss(resourceStr);
         valueFn(ss);
 
-        // Only print entry if it's string is small enough
+        // Only print entry if its string is small enough.
         if (resourceStr.size() > charLimit.value())
           return;
 
-        printFormatting();
-        os << "      " << key << ": " << resourceStr;
+        // Don't recompute resourceStr when valueFn is called below.
+        valueFn = printResourceStr;
+      }
+
+      // Emit the top-level resource entry if we haven't yet.
+      if (!std::exchange(hadResource, true)) {
+        if (needResourceComma)
+          os << "," << newLine;
+        os << "  " << dictName << "_resources: {" << newLine;
+      }
+      // Emit the parent resource entry if we haven't yet.
+      if (!std::exchange(hadEntry, true)) {
+        if (needEntryComma)
+          os << "," << newLine;
+        os << "    " << name << ": {" << newLine;
       } else {
-        printFormatting();
-        os << "      " << key << ": ";
-        valueFn(os);
+        os << "," << newLine;
       }
+      os << "      ";
+      ::printKeywordOrString(key, os);
+      os << ": ";
+      // Call printResourceStr or original valueFn, depending on charLimit.
+      valueFn(os);
     };
     ResourceBuilder entryBuilder(printFn);
     provider.buildResources(op, providerArgs..., entryBuilder);
diff --git a/mlir/test/Bytecode/resources.mlir b/mlir/test/Bytecode/resources.mlir
index 33ed01d20fa0c5..3ef220e890042c 100644
--- a/mlir/test/Bytecode/resources.mlir
+++ b/mlir/test/Bytecode/resources.mlir
@@ -4,21 +4,21 @@
 module @TestDialectResources attributes {
   // CHECK: bytecode.test = dense_resource<decl_resource> : tensor<2xui32>
   // CHECK: bytecode.test2 = dense_resource<resource> : tensor<4xf64>
-  // CHECK: bytecode.test3 = dense_resource<resource_2> : tensor<4xf64>
+  // CHECK: bytecode.test3 = dense_resource<"resource\09two"> : tensor<4xf64>
   bytecode.test = dense_resource<decl_resource> : tensor<2xui32>,
   bytecode.test2 = dense_resource<resource> : tensor<4xf64>,
-  bytecode.test3 = dense_resource<resource_2> : tensor<4xf64>
+  bytecode.test3 = dense_resource<"resource\09two"> : tensor<4xf64>
 } {}
 
 // CHECK: builtin: {
 // CHECK-NEXT: resource: "0x08000000010000000000000002000000000000000300000000000000"
-// CHECK-NEXT: resource_2: "0x08000000010000000000000002000000000000000300000000000000"
+// CHECK-NEXT: "resource\09two": "0x08000000010000000000000002000000000000000300000000000000"
 
 {-#
   dialect_resources: {
     builtin: {
       resource: "0x08000000010000000000000002000000000000000300000000000000",
-      resource_2: "0x08000000010000000000000002000000000000000300000000000000"
+      "resource\09two": "0x08000000010000000000000002000000000000000300000000000000"
     }
   }
 #-}
diff --git a/mlir/test/IR/dense-resource-elements-attr.mlir b/mlir/test/IR/dense-resource-elements-attr.mlir
index adba97994ff60f..44cefc3aa1616d 100644
--- a/mlir/test/IR/dense-resource-elements-attr.mlir
+++ b/mlir/test/IR/dense-resource-elements-attr.mlir
@@ -11,3 +11,18 @@
     }
   }
 #-}
+
+// -----
+
+// DenseResourceElementsHandle key blob\-"one" is quoted and escaped.
+// CHECK: attr = dense_resource<"blob\\-\22one\22"> : tensor<2xi16>
+"test.user_op"() {attr = dense_resource<"blob\\-\22one\22"> : tensor<2xi16>} : () -> ()
+
+{-#
+  dialect_resources: {
+    builtin: {
+      // CHECK: "blob\\-\22one\22": "0x0200000001000200"
+      "blob\\-\22one\22": "0x0200000001000200"
+    }
+  }
+#-}
diff --git a/mlir/test/IR/pretty-resources-print.mlir b/mlir/test/IR/pretty-resources-print.mlir
index 625967fcb76038..297c83bbb13896 100644
--- a/mlir/test/IR/pretty-resources-print.mlir
+++ b/mlir/test/IR/pretty-resources-print.mlir
@@ -12,7 +12,7 @@
 // CHECK:      {-#
 // CHECK-NEXT:   external_resources: {
 // CHECK-NEXT:     external: {
-// CHECK-NEXT:       bool: true,
+// CHECK-NEXT:       "backslash\\tab\09": true,
 // CHECK-NEXT:       string: "\22string\22"
 // CHECK-NEXT:     },
 // CHECK-NEXT:     other_stuff: {
@@ -31,8 +31,8 @@
   external_resources: {
     external: {
       blob: "0x08000000010000000000000002000000000000000300000000000000",
-      bool: true,
-      string: "\"string\"" // with escape characters
+      "backslash\\tab\09": true, // quoted key with escape characters
+      string: "\"string\"" // string with escape characters
     },
     other_stuff: {
       bool: true

@llvmbot
Member

llvmbot commented Dec 12, 2024

@llvm/pr-subscribers-mlir

Author: Soren Lassen (sorenlassen)


-        printFormatting();
-        os << "      " << key << ": " << resourceStr;
+        // Don't recompute resourceStr when valueFn is called below.
+        valueFn = printResourceStr;
Contributor Author

The old implementation of printFn used the printFormatting closure to reduce code duplication across the two charLimit branches, but the code that prints the key was still duplicated, and the new quoting logic made that duplication worse.

Therefore I refactored the implementation to instead “reprogram” valueFn in the case where resourceStr is materialized to check charLimit, so the key and value are printed in one place.
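The "reprogramming" idea can be shown outside MLIR with a small standalone sketch. Everything here is a stand-in: `printEntry` and `ValueFn` are hypothetical, `std::ostream` replaces `raw_ostream`, and the key is printed as-is rather than through the quoting helper.

```cpp
#include <cstddef>
#include <functional>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>

using ValueFn = std::function<void(std::ostream &)>;

// Hypothetical stand-in for the printFn closure: prints "key: value" and
// skips the entry when a limit is set and the rendered value exceeds it.
// Returns true if the entry was printed.
bool printEntry(std::ostream &os, const std::string &key, ValueFn valueFn,
                std::optional<std::size_t> charLimit) {
  std::string buffered;
  if (charLimit) {
    // Render the value once so its size can be checked against the limit.
    std::ostringstream ss;
    valueFn(ss);
    buffered = ss.str();
    if (buffered.size() > *charLimit)
      return false;
    // Replace valueFn so the shared print path below replays the buffered
    // text instead of computing the value a second time.
    valueFn = [&buffered](std::ostream &out) { out << buffered; };
  }
  // Single print path for both the limited and the unlimited case.
  os << key << ": ";
  valueFn(os);
  return true;
}
```

With this shape the key-printing code appears once, so a change to how keys are written only has to be made in one place.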
