Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir-tblgen] Migrate away from PointerUnion::{is,get} (NFC) #119994

Merged

Conversation

kazutakahirata
Copy link
Contributor

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa, cast and the llvm::dyn_cast

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.
@llvmbot
Copy link
Member

llvmbot commented Dec 15, 2024

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir-core

Author: Kazu Hirata (kazutakahirata)

Changes

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.


Full diff: https://github.com/llvm/llvm-project/pull/119994.diff

4 Files Affected:

  • (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+6-6)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+1-1)
  • (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+9-9)
  • (modified) mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp (+7-5)
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 59ae7777da5675..30a2d4da573dc5 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3162,7 +3162,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                              isOptional);
       continue;
     }
-    const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
+    const NamedAttribute &namedAttr = *cast<NamedAttribute *>(arg);
     const Attribute &attr = namedAttr.attr;
 
     // Inferred attributes don't need to be added to the param list.
@@ -3499,14 +3499,14 @@ void OpEmitter::genSideEffectInterfaceMethods() {
   /// Attributes and Operands.
   for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
     Argument arg = op.getArg(i);
-    if (arg.is<NamedTypeConstraint *>()) {
+    if (isa<NamedTypeConstraint *>(arg)) {
       resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
       ++operandIt;
       continue;
     }
-    if (arg.is<NamedProperty *>())
+    if (isa<NamedProperty *>(arg))
       continue;
-    const NamedAttribute *attr = arg.get<NamedAttribute *>();
+    const NamedAttribute *attr = cast<NamedAttribute *>(arg);
     if (attr->attr.getBaseAttr().isSymbolRefAttr())
       resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
   }
@@ -3547,7 +3547,7 @@ void OpEmitter::genSideEffectInterfaceMethods() {
                     .str();
       } else if (location.kind == EffectKind::Symbol) {
         // A symbol reference requires adding the proper attribute.
-        const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
+        const auto *attr = cast<NamedAttribute *>(op.getArg(location.index));
         std::string argName = op.getGetterName(attr->name);
         if (attr->attr.isOptional()) {
           body << "  if (auto symbolRef = " << argName << "Attr())\n  "
@@ -3648,7 +3648,7 @@ void OpEmitter::genTypeInterfaceMethods() {
           // If this is an attribute, index into the attribute dictionary.
         } else {
           auto *attr =
-              op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
+              cast<NamedAttribute *>(op.getArg(arg.operandOrAttributeIndex()));
           body << "  ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
                << " = ";
           if (op.getDialect().usePropertiesForAttributes()) {
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 5019b69d91127e..300d9977f03994 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -628,7 +628,7 @@ static void populateBuilderArgs(const Operator &op,
       name = formatv("_gen_arg_{0}", i);
     name = sanitizeName(name);
     builderArgs.push_back(name);
-    if (!op.getArg(i).is<NamedAttribute *>())
+    if (!isa<NamedAttribute *>(op.getArg(i)))
       operandNames.push_back(name);
   }
 }
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index e2764f8493cd8a..a041c4d3277798 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -655,7 +655,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
     }
 
     // Next handle DAG leaf: operand or attribute
-    if (opArg.is<NamedTypeConstraint *>()) {
+    if (isa<NamedTypeConstraint *>(opArg)) {
       auto operandName =
           formatv("{0}.getODSOperands({1})", castedName, nextOperand);
       emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
@@ -663,7 +663,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                        /*argName=*/tree.getArgName(i), opArgIdx,
                        /*variadicSubIndex=*/std::nullopt);
       ++nextOperand;
-    } else if (opArg.is<NamedAttribute *>()) {
+    } else if (isa<NamedAttribute *>(opArg)) {
       emitAttributeMatch(tree, opName, opArgIdx, depth);
     } else {
       PrintFatalError(loc, "unhandled case when matching op");
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
                                       int argIndex,
                                       std::optional<int> variadicSubIndex) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *operand = op.getArg(operandIndex).get<NamedTypeConstraint *>();
+  auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
 
   // If a constraint is specified, we need to generate C++ statements to
   // check the constraint.
@@ -770,7 +770,7 @@ void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
       // need to queue the operation only if the matching success. Thus we emit
       // the code at the end.
       tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
-    } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
+    } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
       emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
                        operandIndex,
                        /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
@@ -851,7 +851,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
       os << formatv("tblgen_ops.push_back({0});\n", argName);
 
       os.unindent() << "}\n";
-    } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
+    } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
       auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
       emitOperandMatch(tree, opName, operandName.str(), operandIndex,
                        /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
@@ -867,7 +867,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
                                         int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
+  auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
   const auto &attr = namedAttr->attr;
 
   os << "{\n";
@@ -1775,7 +1775,7 @@ void PatternEmitter::supplyValuesForOpArgs(
       auto patArgName = node.getArgName(argIndex);
       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
         // TODO: Refactor out into map to avoid recomputing these.
-        if (!opArg.is<NamedAttribute *>())
+        if (!isa<NamedAttribute *>(opArg))
           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
         if (!patArgName.empty())
           os << "/*" << patArgName << "=*/";
@@ -1805,7 +1805,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
   bool hasOperandSegmentSizes = false;
   std::vector<std::string> sizes;
   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
-    if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
+    if (isa<NamedAttribute *>(resultOp.getArg(argIndex))) {
       // The argument in the op definition.
       auto opArgName = resultOp.getArgName(argIndex);
       hasOperandSegmentSizes =
@@ -1826,7 +1826,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
     }
 
     const auto *operand =
-        resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
+        cast<NamedTypeConstraint *>(resultOp.getArg(argIndex));
     std::string varName;
     if (operand->isVariadic()) {
       ++numVariadic;
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 75286231f5902f..75b8829be4da59 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -33,7 +33,9 @@
 #include <optional>
 
 using llvm::ArrayRef;
+using llvm::cast;
 using llvm::formatv;
+using llvm::isa;
 using llvm::raw_ostream;
 using llvm::raw_string_ostream;
 using llvm::Record;
@@ -607,11 +609,11 @@ static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
   bool areOperandsAheadOfAttrs = true;
   // Find the first attribute.
   const Argument *it = llvm::find_if(op.getArgs(), [](const Argument &arg) {
-    return arg.is<NamedAttribute *>();
+    return isa<NamedAttribute *>(arg);
   });
   // Check whether all following arguments are attributes.
   for (const Argument *ie = op.arg_end(); it != ie; ++it) {
-    if (!it->is<NamedAttribute *>()) {
+    if (!isa<NamedAttribute *>(*it)) {
       areOperandsAheadOfAttrs = false;
       break;
     }
@@ -642,7 +644,7 @@ static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
   for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
     auto argument = op.getArg(i);
     os << tabs << "{\n";
-    if (argument.is<NamedTypeConstraint *>()) {
+    if (isa<NamedTypeConstraint *>(argument)) {
       os << tabs
          << formatv("  for (auto arg : {0}.getODSOperands({1})) {{\n", opVar,
                     operandNum);
@@ -657,7 +659,7 @@ static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
       os << "    }\n";
       operandNum++;
     } else {
-      NamedAttribute *attr = argument.get<NamedAttribute *>();
+      NamedAttribute *attr = cast<NamedAttribute *>(argument);
       auto newtabs = tabs.str() + "  ";
       emitAttributeSerialization(
           (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
@@ -962,7 +964,7 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
       os << tabs << "}\n";
     } else {
       os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words);
-      auto *attr = argument.get<NamedAttribute *>();
+      auto *attr = cast<NamedAttribute *>(argument);
       auto newtabs = tabs.str() + "  ";
       emitAttributeDeserialization(
           (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),

@kazutakahirata kazutakahirata merged commit b0746c6 into llvm:main Dec 15, 2024
12 checks passed
@kazutakahirata kazutakahirata deleted the cleanup_001_PointerUnion_mlir_tblgen branch December 15, 2024 17:18
raikonenfnu added a commit to iree-org/llvm-project that referenced this pull request Dec 17, 2024
raikonenfnu added a commit to iree-org/llvm-project that referenced this pull request Dec 17, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Dec 17, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:spirv mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants