Skip to content

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Sep 14, 2025

This PR adds type hints for accessors in the generated builders.

@makslevental makslevental force-pushed the users/makslevental/fix-accessors branch from b5f93d7 to d645886 Compare September 15, 2025 00:13
@makslevental makslevental marked this pull request as ready for review September 15, 2025 00:21
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Sep 15, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 15, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Maksim Levental (makslevental)

Changes

This PR adds type hints for accessors in the generated builders.


Full diff: https://github.com/llvm/llvm-project/pull/158455.diff

3 Files Affected:

  • (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+26-26)
  • (modified) mlir/test/python/dialects/python_test.py (+13-4)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+52-19)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 3ec69c33b4bb9..b2415d8eea742 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -36,7 +36,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
   // CHECK:   operand_range = _ods_segmented_accessor(
   // CHECK:       self.operation.operands,
   // CHECK:       self.operation.attributes["operandSegmentSizes"], 0)
@@ -44,14 +44,14 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK-NOT: if len(operand_range)
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.Value:
   // CHECK:   operand_range = _ods_segmented_accessor(
   // CHECK:       self.operation.operands,
   // CHECK:       self.operation.attributes["operandSegmentSizes"], 1)
   // CHECK:   return operand_range[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _Optional[_ods_ir.Value]:
   // CHECK:   operand_range = _ods_segmented_accessor(
   // CHECK:       self.operation.operands,
   // CHECK:       self.operation.attributes["operandSegmentSizes"], 2)
@@ -84,21 +84,21 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _Optional[_ods_ir.OpResult]:
   // CHECK:   result_range = _ods_segmented_accessor(
   // CHECK:       self.operation.results,
   // CHECK:       self.operation.attributes["resultSegmentSizes"], 0)
   // CHECK:   return result_range[0] if len(result_range) > 0 else None
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
   // CHECK:   result_range = _ods_segmented_accessor(
   // CHECK:       self.operation.results,
   // CHECK:       self.operation.attributes["resultSegmentSizes"], 1)
   // CHECK:   return result_range[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _ods_ir.OpResultList:
   // CHECK:   result_range = _ods_segmented_accessor(
   // CHECK:       self.operation.results,
   // CHECK:       self.operation.attributes["resultSegmentSizes"], 2)
@@ -320,16 +320,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def f32(self):
+  // CHECK: def f32(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[1]
   let arguments = (ins I32, F32:$f32, I64);
 
   // CHECK: @builtins.property
-  // CHECK: def i32(self):
+  // CHECK: def i32(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def i64(self):
+  // CHECK: def i64(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[2]
   let results = (outs I32:$i32, AnyFloat, I64:$i64);
 }
@@ -358,11 +358,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def non_optional(self):
+  // CHECK: def non_optional(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
 
   // CHECK: @builtins.property
-  // CHECK: def optional(self):
+  // CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
   // CHECK:   return None if len(self.operation.operands) < 2 else self.operation.operands[1]
 }
 
@@ -389,11 +389,11 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic(self):
+  // CHECK: def variadic(self) -> _ods_ir.OpOperandList:
   // CHECK:   _ods_variadic_group_length = len(self.operation.operands) - 2 + 1
   // CHECK:   return self.operation.operands[1:1 + _ods_variadic_group_length]
   let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
@@ -422,12 +422,12 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def variadic(self):
+  // CHECK: def variadic(self) -> _ods_ir.OpResultList:
   // CHECK:   _ods_variadic_group_length = len(self.operation.results) - 2 + 1
   // CHECK:   return self.operation.results[0:0 + _ods_variadic_group_length]
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
   // CHECK:   _ods_variadic_group_length = len(self.operation.results) - 2 + 1
   // CHECK:   return self.operation.results[1 + _ods_variadic_group_length - 1]
   let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
@@ -453,7 +453,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def in_(self):
+  // CHECK: def in_(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
   let arguments = (ins AnyType:$in);
 }
@@ -491,17 +491,17 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
 def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
                                        [SameVariadicOperandSize]> {
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 0)
   // CHECK:   return self.operation.operands[start:start + elements_per_group]
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.Value:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 1)
   // CHECK:   return self.operation.operands[start]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _ods_ir.OpOperandList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 1, 1)
   // CHECK:   return self.operation.operands[start:start + elements_per_group]
   let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -517,17 +517,17 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
 def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
                                       [SameVariadicResultSize]> {
   // CHECK: @builtins.property
-  // CHECK: def variadic1(self):
+  // CHECK: def variadic1(self) -> _ods_ir.OpResultList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 0)
   // CHECK:   return self.operation.results[start:start + elements_per_group]
   //
   // CHECK: @builtins.property
-  // CHECK: def non_variadic(self):
+  // CHECK: def non_variadic(self) -> _ods_ir.OpResult:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 1)
   // CHECK:   return self.operation.results[start]
   //
   // CHECK: @builtins.property
-  // CHECK: def variadic2(self):
+  // CHECK: def variadic2(self) -> _ods_ir.OpResultList:
   // CHECK:   start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 1, 1)
   // CHECK:   return self.operation.results[start:start + elements_per_group]
   let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -557,20 +557,20 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
 
   // CHECK: @builtins.property
-  // CHECK: def i32(self):
+  // CHECK: def i32(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def f32(self):
+  // CHECK: def f32(self) -> _ods_ir.Value:
   // CHECK:   return self.operation.operands[1]
   let arguments = (ins I32:$i32, F32:$f32);
 
   // CHECK: @builtins.property
-  // CHECK: def i64(self):
+  // CHECK: def i64(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[0]
   //
   // CHECK: @builtins.property
-  // CHECK: def f64(self):
+  // CHECK: def f64(self) -> _ods_ir.OpResult:
   // CHECK:   return self.operation.results[1]
   let results = (outs I64:$i64, AnyFloat:$f64);
 }
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 68262822ca6b5..faba6904870ac 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -323,6 +323,7 @@ def resultTypesDefinedByTraits():
             # CHECK: f32 index
             print(no_infer.single.type, no_infer.doubled.type)
 
+
 # CHECK-LABEL: TEST: testOptionalOperandOp
 @run
 def testOptionalOperandOp():
@@ -615,11 +616,17 @@ def values(lst):
                 [zero, one], two, [three, four]
             )
             # CHECK: Value(%{{.*}} = arith.constant 2 : i32)
-            print(variadic_operands.non_variadic)
+            non_variadic = variadic_operands.non_variadic
+            print(non_variadic)
+            assert isinstance(non_variadic, Value)
             # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
-            print(values(variadic_operands.variadic1))
+            variadic1 = variadic_operands.variadic1
+            print(values(variadic1))
+            assert isinstance(variadic1, OpOperandList)
             # CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
-            print(values(variadic_operands.variadic2))
+            variadic2 = variadic_operands.variadic2
+            print(values(variadic2))
+            assert isinstance(variadic2, OpOperandList)
 
 
 # CHECK-LABEL: TEST: testVariadicResultAccess
@@ -660,7 +667,9 @@ def types(lst):
             # CHECK: i1
             print(op.non_variadic2.type)
             # CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
-            print(types(op.variadic))
+            variadic = op.variadic
+            print(types(variadic))
+            assert isinstance(variadic, OpResultList)
 
             #  Test Variadic-Variadic-Fixed
             op = test.SameVariadicResultSizeOpVVF(
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 6a7aa9e3432d5..17d23b3e3ac5e 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -44,7 +44,7 @@ _ods_ir = _ods_cext.ir
 _ods_cext.globals.register_traceback_file_exclusion(__file__)
 
 import builtins
-from typing import Sequence as _Sequence, Union as _Union
+from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional
 
 )Py";
 
@@ -93,9 +93,10 @@ constexpr const char *opClassRegionSpecTemplate = R"Py(
 ///   {0} is the name of the accessor;
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the position in the element list.
+///   {3} is the type hint.
 constexpr const char *opSingleTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {3}:
     return self.operation.{1}s[{2}]
 )Py";
 
@@ -104,11 +105,12 @@ constexpr const char *opSingleTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the total number of element groups;
 ///   {3} is the position of the current group in the group list.
+///   {4} is the type hint.
 /// This works for both a single variadic group (non-negative length) and an
 /// single optional element (zero length if the element is absent).
 constexpr const char *opSingleAfterVariableTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {4}:
     _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
     return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
 )Py";
@@ -118,12 +120,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the total number of element groups;
 ///   {3} is the position of the current group in the group list.
+///   {4} is the type hint.
 /// This works if we have only one variable-length group (and it's the optional
 /// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
 /// smaller than the total number of groups.
 constexpr const char *opOneOptionalTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> _Optional[{4}]:
     return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
 )Py";
 
@@ -132,9 +135,10 @@ constexpr const char *opOneOptionalTemplate = R"Py(
 ///   {1} is either 'operand' or 'result';
 ///   {2} is the total number of element groups;
 ///   {3} is the position of the current group in the group list.
+///   {4} is the type hint.
 constexpr const char *opOneVariadicTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {4}:
     _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
     return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
 )Py";
@@ -146,9 +150,10 @@ constexpr const char *opOneVariadicTemplate = R"Py(
 ///   {3} is the total number of variadic groups;
 ///   {4} is the number of non-variadic groups preceding the current group;
 ///   {5} is the number of variadic groups preceding the current group.
+///   {6} is the type hint.
 constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {6}:
     start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
 
 /// Second part of the template for equally-sized case, accessing a single
@@ -171,9 +176,10 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
 ///   {2} is the position of the group in the group list;
 ///   {3} is a return suffix (expected [0] for single-element, empty for
 ///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
+///   {4} is the type hint.
 constexpr const char *opVariadicSegmentTemplate = R"Py(
   @builtins.property
-  def {0}(self):
+  def {0}(self) -> {4}:
     {1}_range = _ods_segmented_accessor(
          self.operation.{1}s,
          self.operation.attributes["{1}SegmentSizes"], {2})
@@ -357,15 +363,24 @@ static void emitElementAccessors(
         seenVariableLength = true;
       if (element.name.empty())
         continue;
+      const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+                                                           : "_ods_ir.OpResult";
       if (element.isVariableLength()) {
-        os << formatv(element.isOptional() ? opOneOptionalTemplate
-                                           : opOneVariadicTemplate,
-                      sanitizeName(element.name), kind, numElements, i);
+        if (element.isOptional()) {
+          os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
+                        numElements, i, type);
+        } else {
+          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
+                                                   : "_ods_ir.OpResultList";
+          os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind,
+                        numElements, i, type);
+        }
       } else if (seenVariableLength) {
         os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
-                      kind, numElements, i);
+                      kind, numElements, i, type);
       } else {
-        os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
+        os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i,
+                      type);
       }
     }
     return;
@@ -388,9 +403,17 @@ static void emitElementAccessors(
     for (unsigned i = 0; i < numElements; ++i) {
       const NamedTypeConstraint &element = getElement(op, i);
       if (!element.name.empty()) {
+        std::string type;
+        if (element.isVariableLength()) {
+          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
+                                                   : "_ods_ir.OpResultList";
+        } else {
+          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+                                                   : "_ods_ir.OpResult";
+        }
         os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
                       kind, numSimpleLength, numVariadicGroups,
-                      numPrecedingSimple, numPrecedingVariadic);
+                      numPrecedingSimple, numPrecedingVariadic, type);
         os << formatv(element.isVariableLength()
                           ? opVariadicEqualVariadicTemplate
                           : opVariadicEqualSimpleTemplate,
@@ -413,13 +436,23 @@ static void emitElementAccessors(
       if (element.name.empty())
         continue;
       std::string trailing;
-      if (!element.isVariableLength())
-        trailing = "[0]";
-      else if (element.isOptional())
-        trailing = std::string(
-            formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+      std::string type = std::strcmp(kind, "operand") == 0
+                             ? "_ods_ir.OpOperandList"
+                             : "_ods_ir.OpResultList";
+      if (!element.isVariableLength() || element.isOptional()) {
+        type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
+                                                 : "_ods_ir.OpResult";
+        if (!element.isVariableLength()) {
+          trailing = "[0]";
+        } else if (element.isOptional()) {
+          type = "_Optional[" + type + "]";
+          trailing = std::string(
+              formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+        }
+      }
+
       os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
-                    i, trailing);
+                    i, trailing, type);
     }
     return;
   }

@makslevental makslevental force-pushed the users/makslevental/fix-accessors branch 3 times, most recently from 72d3197 to c986aac Compare September 15, 2025 00:50
Copy link

github-actions bot commented Sep 15, 2025

✅ With the latest revision this PR passed the Python code formatter.

@makslevental makslevental force-pushed the users/makslevental/fix-accessors branch from c986aac to 87b5257 Compare September 15, 2025 00:54
Copy link
Contributor

@superbobry superbobry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, but I think it would be more helpful to have op-specific subtypes instead of just ir.Attribute for attributes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think you can use | for both Union and Optional. We were doing that in the hand-written ir.pyi.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pipe syntax is supported only as of 3.9, which is actually EOL as of next month but still easier to just keep it this way.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we need to probably restart the discussion to bump to newer minimum version as our current one is very old ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use a more specific type for attributes here and below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not straightforward but sure lemme try.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay the attributes are now concrete @superbobry can you give it a spot check to make sure I didn't mess anything up...

@makslevental makslevental force-pushed the users/makslevental/fix-accessors branch 2 times, most recently from 6b1a72d to 274f5d3 Compare September 15, 2025 22:22
@makslevental makslevental force-pushed the users/makslevental/fix-accessors branch from 274f5d3 to 2033bb6 Compare September 15, 2025 22:25
Copy link

github-actions bot commented Sep 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@makslevental makslevental force-pushed the users/makslevental/fix-accessors branch 3 times, most recently from b7f2573 to a3b9ad8 Compare September 15, 2025 23:37
@makslevental makslevental force-pushed the users/makslevental/fix-accessors branch from a3b9ad8 to c2bcd9c Compare September 15, 2025 23:43
Copy link
Contributor

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing -- I love it!

OoC: Are you testing this somehow, maybe with pyright or mypy? Should we set up some CI or CMake targets (eventually) or at least document this somehow?

}

static std::string getPythonAttrName(mlir::tblgen::Attribute attr) {
auto storageTypeStr = attr.getStorageType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine for the initial version, but we might also consider having an extensible system based on tblgen. @ftynse had some ideas on how to do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd love to hear some of these ideas - the only thing I managed to think of was forking this code downstream 🤷‍♂️

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC one of the ideas was to have the mapping in a .td file instead of C++ :)

Copy link
Contributor Author

@makslevental makslevental Sep 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I think that's a good idea - specifically an optional field on TD types that's something like python_type_annotation. Of course, as with all things "core", I wonder if there would be complaints 🤷.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd do it standalone from the TD type, but in the TD processed by python generator and only it. Its just a mapping there. We did something like this for TFLite ages ago, but can't find it quickly now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

meaning specify a mapping using tablgen hmm. okay I guess that's feasible - something like

class PythonBindingTypeHint<Type mlirType, string pythonType>...
def YourDownstreamPythonHint<YourDownstreamType, "your.downstream.type.hint">

I can do that in a follow-up - for right now let's just fix these for upstream attrs.

@makslevental
Copy link
Contributor Author

makslevental commented Sep 16, 2025

Amazing -- I love it!

OoC: Are you testing this somehow, maybe with pyright or mypy? Should we set up some CI or CMake targets (eventually) or at least document this somehow?

#157569

🙂

OoC: Are you testing this somehow

Actually there's a test here right which manually tests that the annotations actually match the return types at least for the cases tested by python_test (which is close to all I believe).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we need to probably restart the discussion to bump to newer minimum version as our current one is very old ...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is _ods_ir etc always sufficient? (was thinking how C++ side we end up doing ::mlir:: to ensure namespace, and not sure if we could run into naming conflicts here - that being said _ods_ir should have very low overlap chance coincidentally)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_ods_ir is how the ir module is already being referenced in these generated files:

}

static std::string getPythonAttrName(mlir::tblgen::Attribute attr) {
auto storageTypeStr = attr.getStorageType();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd do it standalone from the TD type, but in the TD processed by python generator and only it. Its just a mapping there. We did something like this for TFLite ages ago, but can't find it quickly now.

@makslevental makslevental force-pushed the users/makslevental/fix-accessors branch 9 times, most recently from 3a765b2 to d79db7a Compare September 17, 2025 21:23
@makslevental makslevental force-pushed the users/makslevental/fix-accessors branch from d79db7a to 61eb478 Compare September 17, 2025 21:25
@makslevental makslevental merged commit 67f43c6 into llvm:main Sep 19, 2025
9 checks passed
@makslevental makslevental deleted the users/makslevental/fix-accessors branch September 19, 2025 02:12
SeongjaeP pushed a commit to SeongjaeP/llvm-project that referenced this pull request Sep 23, 2025
This PR adds type hints for accessors in the generated builders.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants