Skip to content

Commit 55d9c91

Browse files
authored
[MLIR][Python] Add optional results parameter for building op with inferable result types (#156818)
Currently in MLIR python bindings, operations with inferable result types (e.g. with `InferTypeOpInterface` or `SameOperandsAndResultType`) will generate such builder functions: ```python def my_op(arg1, arg2 .. argN, *, loc=None, ip=None): ... # result types will be inferred automatically ``` However, in some cases we may want to provide the result types explicitly. For example, the implementation of interface method `inferResultTypes(..)` can return a failure and then we cannot build the op in that way. Also, in the C++ side we have multiple `build` methods for both explicitly specify the result types and automatically inferring them. In this PR, we change the signature of this builder function to: ```python def my_op(arg1, arg2 .. argN, *, results=None, loc=None, ip=None): ... # result types will be inferred automatically if results is None ``` If the `results` is not provided, it will be inferred automatically, otherwise the provided result types will be utilized. Also, `__init__` methods of the generated op classes are changed correspondingly. Note that for operations without inferable result types, the signature remain unchanged, i.e. `def my_op(res1 .. resN, arg1 .. argN, *, loc=None, ip=None)`. --- Previously I have considered an approach like `my_op(arg, *, res1=None, res2=None, loc=None, ip=None)`, but I quickly realized it had some issues. For example, if the user only provides some of the arguments—say `my_op(v1, res1=i32)`—this could lead to problems. Moreover, we don’t seem to have a mechanism for inferring only part of result types. A unified `results` parameter seems to be more simple and straightforward.
1 parent 882575f commit 55d9c91

File tree

4 files changed

+76
-56
lines changed

4 files changed

+76
-56
lines changed

mlir/test/mlir-tblgen/op-python-bindings.td

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
2323
[AttrSizedOperandSegments]> {
2424
// CHECK: def __init__(self, variadic1, non_variadic, *, variadic2=None, loc=None, ip=None):
2525
// CHECK: operands = []
26-
// CHECK: results = []
2726
// CHECK: attributes = {}
2827
// CHECK: regions = None
2928
// CHECK: operands.append(_get_op_results_or_values(variadic1))
3029
// CHECK: operands.append(non_variadic)
3130
// CHECK: operands.append(variadic2)
31+
// CHECK: results = []
3232
// CHECK: _ods_successors = None
3333
// CHECK: super().__init__(
3434
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -71,9 +71,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
7171
[AttrSizedResultSegments]> {
7272
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
7373
// CHECK: operands = []
74-
// CHECK: results = []
7574
// CHECK: attributes = {}
7675
// CHECK: regions = None
76+
// CHECK: results = []
7777
// CHECK: if variadic1 is not None: results.append(variadic1)
7878
// CHECK: results.append(non_variadic)
7979
// CHECK: results.append(variadic2)
@@ -120,7 +120,6 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
120120
def AttributedOp : TestOp<"attributed_op"> {
121121
// CHECK: def __init__(self, i32attr, in_, *, optionalF32Attr=None, unitAttr=None, loc=None, ip=None):
122122
// CHECK: operands = []
123-
// CHECK: results = []
124123
// CHECK: attributes = {}
125124
// CHECK: regions = None
126125
// CHECK: attributes["i32attr"] = (i32attr if (
@@ -131,6 +130,7 @@ def AttributedOp : TestOp<"attributed_op"> {
131130
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
132131
// CHECK: _ods_get_default_loc_context(loc))
133132
// CHECK: attributes["in"] = (in_
133+
// CHECK: results = []
134134
// CHECK: _ods_successors = None
135135
// CHECK: super().__init__(
136136
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -170,14 +170,14 @@ def AttributedOp : TestOp<"attributed_op"> {
170170
def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
171171
// CHECK: def __init__(self, _gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None):
172172
// CHECK: operands = []
173-
// CHECK: results = []
174173
// CHECK: attributes = {}
175174
// CHECK: regions = None
176175
// CHECK: operands.append(_gen_arg_0)
177176
// CHECK: operands.append(_gen_arg_2)
178177
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
179178
// CHECK: _ods_get_default_loc_context(loc))
180179
// CHECK: if is_ is not None: attributes["is"] = (is_
180+
// CHECK: results = []
181181
// CHECK: _ods_successors = None
182182
// CHECK: super().__init__(
183183
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -205,11 +205,11 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
205205
def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
206206
// CHECK: def __init__(self, *, arr=None, unsupported=None, loc=None, ip=None):
207207
// CHECK: operands = []
208-
// CHECK: results = []
209208
// CHECK: attributes = {}
210209
// CHECK: regions = None
211210
// CHECK: if arr is not None: attributes["arr"] = (arr
212211
// CHECK: if unsupported is not None: attributes["unsupported"] = (unsupported
212+
// CHECK: results = []
213213
// CHECK: _ods_successors = None
214214
// CHECK: super().__init__(
215215
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -226,21 +226,21 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
226226

227227
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
228228
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
229-
// CHECK: def __init__(self, type_, *, loc=None, ip=None):
229+
// CHECK: def __init__(self, type_, *, results=None, loc=None, ip=None):
230230
// CHECK: operands = []
231-
// CHECK: results = []
232-
// CHECK: _ods_result_type_source_attr = attributes["type"]
233-
// CHECK: _ods_derived_result_type = (
231+
// CHECK: if results is None:
232+
// CHECK: _ods_result_type_source_attr = attributes["type"]
233+
// CHECK: _ods_derived_result_type = (
234234
// CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value
235235
// CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
236236
// CHECK: _ods_result_type_source_attr.type)
237-
// CHECK: results.extend([_ods_derived_result_type] * 2)
237+
// CHECK: results = [_ods_derived_result_type] * 2
238238
let arguments = (ins TypeAttr:$type);
239239
let results = (outs AnyType:$res, AnyType);
240240
}
241241

242-
// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
243-
// CHECK: return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results
242+
// CHECK: def derive_result_types_op(type_, *, results=None, loc=None, ip=None)
243+
// CHECK: return DeriveResultTypesOp(type_=type_, results=results, loc=loc, ip=ip).results
244244

245245
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
246246
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -258,9 +258,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
258258
def EmptyOp : TestOp<"empty">;
259259
// CHECK: def __init__(self, *, loc=None, ip=None):
260260
// CHECK: operands = []
261-
// CHECK: results = []
262261
// CHECK: attributes = {}
263262
// CHECK: regions = None
263+
// CHECK: results = []
264264
// CHECK: _ods_successors = None
265265
// CHECK: super().__init__(
266266
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -272,44 +272,44 @@ def EmptyOp : TestOp<"empty">;
272272

273273
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
274274
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
275-
// CHECK: def __init__(self, *, loc=None, ip=None):
275+
// CHECK: def __init__(self, *, results=None, loc=None, ip=None):
276276
// CHECK: _ods_context = _ods_get_default_loc_context(loc)
277277
// CHECK: super().__init__(
278278
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
279-
// CHECK: attributes=attributes, operands=operands,
279+
// CHECK: attributes=attributes, results=results, operands=operands,
280280
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
281281
let results = (outs I32:$i32, F32:$f32);
282282
}
283283

284-
// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
285-
// CHECK: return InferResultTypesImpliedOp(loc=loc, ip=ip).results
284+
// CHECK: def infer_result_types_implied_op(*, results=None, loc=None, ip=None)
285+
// CHECK: return InferResultTypesImpliedOp(results=results, loc=loc, ip=ip).results
286286

287287
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
288288
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
289-
// CHECK: def __init__(self, *, loc=None, ip=None):
289+
// CHECK: def __init__(self, *, results=None, loc=None, ip=None):
290290
// CHECK: operands = []
291291
// CHECK: super().__init__(
292292
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
293-
// CHECK: attributes=attributes, operands=operands,
293+
// CHECK: attributes=attributes, results=results, operands=operands,
294294
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
295295
let results = (outs AnyType, AnyType, AnyType);
296296
}
297297

298-
// CHECK: def infer_result_types_op(*, loc=None, ip=None)
299-
// CHECK: return InferResultTypesOp(loc=loc, ip=ip).results
298+
// CHECK: def infer_result_types_op(*, results=None, loc=None, ip=None)
299+
// CHECK: return InferResultTypesOp(results=results, loc=loc, ip=ip).results
300300

301301
// CHECK: @_ods_cext.register_operation(_Dialect)
302302
// CHECK: class MissingNamesOp(_ods_ir.OpView):
303303
// CHECK-LABEL: OPERATION_NAME = "test.missing_names"
304304
def MissingNamesOp : TestOp<"missing_names"> {
305305
// CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None):
306306
// CHECK: operands = []
307-
// CHECK: results = []
308307
// CHECK: attributes = {}
309308
// CHECK: regions = None
310309
// CHECK: operands.append(_gen_arg_0)
311310
// CHECK: operands.append(f32)
312311
// CHECK: operands.append(_gen_arg_2)
312+
// CHECK: results = []
313313
// CHECK: results.append(i32)
314314
// CHECK: results.append(_gen_res_1)
315315
// CHECK: results.append(i64)
@@ -346,11 +346,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
346346
let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
347347
// CHECK: def __init__(self, non_optional, *, optional=None, loc=None, ip=None):
348348
// CHECK: operands = []
349-
// CHECK: results = []
350349
// CHECK: attributes = {}
351350
// CHECK: regions = None
352351
// CHECK: operands.append(non_optional)
353352
// CHECK: if optional is not None: operands.append(optional)
353+
// CHECK: results = []
354354
// CHECK: _ods_successors = None
355355
// CHECK: super().__init__(
356356
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -377,11 +377,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
377377
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
378378
// CHECK: def __init__(self, non_variadic, variadic, *, loc=None, ip=None):
379379
// CHECK: operands = []
380-
// CHECK: results = []
381380
// CHECK: attributes = {}
382381
// CHECK: regions = None
383382
// CHECK: operands.append(non_variadic)
384383
// CHECK: operands.extend(_get_op_results_or_values(variadic))
384+
// CHECK: results = []
385385
// CHECK: _ods_successors = None
386386
// CHECK: super().__init__(
387387
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -410,9 +410,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
410410
def OneVariadicResultOp : TestOp<"one_variadic_result"> {
411411
// CHECK: def __init__(self, variadic, non_variadic, *, loc=None, ip=None):
412412
// CHECK: operands = []
413-
// CHECK: results = []
414413
// CHECK: attributes = {}
415414
// CHECK: regions = None
415+
// CHECK: results = []
416416
// CHECK: results.extend(variadic)
417417
// CHECK: results.append(non_variadic)
418418
// CHECK: _ods_successors = None
@@ -442,10 +442,10 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
442442
def PythonKeywordOp : TestOp<"python_keyword"> {
443443
// CHECK: def __init__(self, in_, *, loc=None, ip=None):
444444
// CHECK: operands = []
445-
// CHECK: results = []
446445
// CHECK: attributes = {}
447446
// CHECK: regions = None
448447
// CHECK: operands.append(in_)
448+
// CHECK: results = []
449449
// CHECK: _ods_successors = None
450450
// CHECK: super().__init__(
451451
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -463,17 +463,16 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
463463

464464
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
465465
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
466-
// CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
466+
// CHECK: def __init__(self, in1, in2, *, results=None, loc=None, ip=None):
467467
// CHECK: operands = []
468-
// CHECK: results = []
469468
// CHECK: operands.append
470-
// CHECK: results.extend([operands[0].type] * 1)
469+
// CHECK: if results is None: results = [operands[0].type] * 1
471470
let arguments = (ins AnyType:$in1, AnyType:$in2);
472471
let results = (outs AnyType:$res);
473472
}
474473

475-
// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
476-
// CHECK: return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)
474+
// CHECK: def same_results(in1, in2, *, results=None, loc=None, ip=None)
475+
// CHECK: return SameResultsOp(in1=in1, in2=in2, results=results, loc=loc, ip=ip)
477476

478477
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
479478
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -544,11 +543,11 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
544543
def SimpleOp : TestOp<"simple"> {
545544
// CHECK: def __init__(self, i64, f64, i32, f32, *, loc=None, ip=None):
546545
// CHECK: operands = []
547-
// CHECK: results = []
548546
// CHECK: attributes = {}
549547
// CHECK: regions = None
550548
// CHECK: operands.append(i32)
551549
// CHECK: operands.append(f32)
550+
// CHECK: results = []
552551
// CHECK: results.append(i64)
553552
// CHECK: results.append(f64)
554553
// CHECK: _ods_successors = None
@@ -584,9 +583,9 @@ def SimpleOp : TestOp<"simple"> {
584583
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
585584
// CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
586585
// CHECK: operands = []
587-
// CHECK: results = []
588586
// CHECK: attributes = {}
589587
// CHECK: regions = None
588+
// CHECK: results = []
590589
// CHECK: _ods_successors = None
591590
// CHECK: regions = 2 + num_variadic
592591
// CHECK: super().__init__(
@@ -612,9 +611,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
612611
def VariadicRegionOp : TestOp<"variadic_region"> {
613612
// CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
614613
// CHECK: operands = []
615-
// CHECK: results = []
616614
// CHECK: attributes = {}
617615
// CHECK: regions = None
616+
// CHECK: results = []
618617
// CHECK: _ods_successors = None
619618
// CHECK: regions = 0 + num_variadic
620619
// CHECK: super().__init__(

mlir/test/python/dialects/python_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,10 @@ def resultTypesDefinedByTraits():
283283
module = Module.create()
284284
with InsertionPoint(module.body):
285285
inferred = test.InferResultsOp()
286+
287+
# CHECK: i32 i64
288+
print(inferred.single.type, inferred.doubled.type)
289+
286290
same = test.SameOperandAndResultTypeOp([inferred.results[0]])
287291
# CHECK-COUNT-2: i32
288292
print(same.one.type)
@@ -309,6 +313,15 @@ def resultTypesDefinedByTraits():
309313
# CHECK: index
310314
print(implied.index.type)
311315

316+
# provide the result types to avoid inferring them
317+
f64 = F64Type.get()
318+
no_imply = test.InferResultsImpliedOp(results=[f64, f64, f64])
319+
# CHECK-COUNT-3: f64
320+
print(no_imply.integer.type, no_imply.flt.type, no_imply.index.type)
321+
322+
no_infer = test.InferResultsOp(results=[F32Type.get(), IndexType.get()])
323+
# CHECK: f32 index
324+
print(no_infer.single.type, no_infer.doubled.type)
312325

313326
# CHECK-LABEL: TEST: testOptionalOperandOp
314327
@run

mlir/test/python/ir/auto_location.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def testInferLocations():
5151
_cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__)
5252
three = arith.constant(IndexType.get(), 3)
5353
# fmt: off
54-
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
54+
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
5555
# fmt: on
5656
print(three.location)
5757

@@ -60,14 +60,14 @@ def foo():
6060
print(four.location)
6161

6262
# fmt: off
63-
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
63+
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
6464
# fmt: on
6565
foo()
6666

6767
_cext.globals.register_traceback_file_exclusion(__file__)
6868

6969
# fmt: off
70-
# CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218))
70+
# CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235))
7171
# fmt: on
7272
foo()
7373

0 commit comments

Comments
 (0)