Skip to content

Commit 87b5257

Browse files
committed
[MLIR][Python] add type hints for accessors
1 parent 30e9cba commit 87b5257

File tree

4 files changed

+123
-64
lines changed

4 files changed

+123
-64
lines changed

mlir/test/mlir-tblgen/op-python-bindings.td

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,22 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
3636
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
3737

3838
// CHECK: @builtins.property
39-
// CHECK: def variadic1(self):
39+
// CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
4040
// CHECK: operand_range = _ods_segmented_accessor(
4141
// CHECK: self.operation.operands,
4242
// CHECK: self.operation.attributes["operandSegmentSizes"], 0)
4343
// CHECK: return operand_range
4444
// CHECK-NOT: if len(operand_range)
4545
//
4646
// CHECK: @builtins.property
47-
// CHECK: def non_variadic(self):
47+
// CHECK: def non_variadic(self) -> _ods_ir.Value:
4848
// CHECK: operand_range = _ods_segmented_accessor(
4949
// CHECK: self.operation.operands,
5050
// CHECK: self.operation.attributes["operandSegmentSizes"], 1)
5151
// CHECK: return operand_range[0]
5252
//
5353
// CHECK: @builtins.property
54-
// CHECK: def variadic2(self):
54+
// CHECK: def variadic2(self) -> _Optional[_ods_ir.Value]:
5555
// CHECK: operand_range = _ods_segmented_accessor(
5656
// CHECK: self.operation.operands,
5757
// CHECK: self.operation.attributes["operandSegmentSizes"], 2)
@@ -84,21 +84,21 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
8484
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
8585

8686
// CHECK: @builtins.property
87-
// CHECK: def variadic1(self):
87+
// CHECK: def variadic1(self) -> _Optional[_ods_ir.OpResult]:
8888
// CHECK: result_range = _ods_segmented_accessor(
8989
// CHECK: self.operation.results,
9090
// CHECK: self.operation.attributes["resultSegmentSizes"], 0)
9191
// CHECK: return result_range[0] if len(result_range) > 0 else None
9292
//
9393
// CHECK: @builtins.property
94-
// CHECK: def non_variadic(self):
94+
// CHECK: def non_variadic(self) -> _ods_ir.OpResult:
9595
// CHECK: result_range = _ods_segmented_accessor(
9696
// CHECK: self.operation.results,
9797
// CHECK: self.operation.attributes["resultSegmentSizes"], 1)
9898
// CHECK: return result_range[0]
9999
//
100100
// CHECK: @builtins.property
101-
// CHECK: def variadic2(self):
101+
// CHECK: def variadic2(self) -> _ods_ir.OpResultList:
102102
// CHECK: result_range = _ods_segmented_accessor(
103103
// CHECK: self.operation.results,
104104
// CHECK: self.operation.attributes["resultSegmentSizes"], 2)
@@ -138,21 +138,21 @@ def AttributedOp : TestOp<"attributed_op"> {
138138
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
139139

140140
// CHECK: @builtins.property
141-
// CHECK: def i32attr(self):
141+
// CHECK: def i32attr(self) -> _ods_ir.Attribute:
142142
// CHECK: return self.operation.attributes["i32attr"]
143143

144144
// CHECK: @builtins.property
145-
// CHECK: def optionalF32Attr(self):
145+
// CHECK: def optionalF32Attr(self) -> _Optional[_ods_ir.Attribute]:
146146
// CHECK: if "optionalF32Attr" not in self.operation.attributes:
147147
// CHECK: return None
148148
// CHECK: return self.operation.attributes["optionalF32Attr"]
149149

150150
// CHECK: @builtins.property
151-
// CHECK: def unitAttr(self):
151+
// CHECK: def unitAttr(self) -> bool:
152152
// CHECK: return "unitAttr" in self.operation.attributes
153153

154154
// CHECK: @builtins.property
155-
// CHECK: def in_(self):
155+
// CHECK: def in_(self) -> _ods_ir.Attribute:
156156
// CHECK: return self.operation.attributes["in"]
157157

158158
let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
@@ -185,11 +185,11 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
185185
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
186186

187187
// CHECK: @builtins.property
188-
// CHECK: def in_(self):
188+
// CHECK: def in_(self) -> bool:
189189
// CHECK: return "in" in self.operation.attributes
190190

191191
// CHECK: @builtins.property
192-
// CHECK: def is_(self):
192+
// CHECK: def is_(self) -> _Optional[_ods_ir.Attribute]:
193193
// CHECK: if "is" not in self.operation.attributes:
194194
// CHECK: return None
195195
// CHECK: return self.operation.attributes["is"]
@@ -320,16 +320,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
320320
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
321321

322322
// CHECK: @builtins.property
323-
// CHECK: def f32(self):
323+
// CHECK: def f32(self) -> _ods_ir.Value:
324324
// CHECK: return self.operation.operands[1]
325325
let arguments = (ins I32, F32:$f32, I64);
326326

327327
// CHECK: @builtins.property
328-
// CHECK: def i32(self):
328+
// CHECK: def i32(self) -> _ods_ir.OpResult:
329329
// CHECK: return self.operation.results[0]
330330
//
331331
// CHECK: @builtins.property
332-
// CHECK: def i64(self):
332+
// CHECK: def i64(self) -> _ods_ir.OpResult:
333333
// CHECK: return self.operation.results[2]
334334
let results = (outs I32:$i32, AnyFloat, I64:$i64);
335335
}
@@ -358,11 +358,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
358358
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
359359

360360
// CHECK: @builtins.property
361-
// CHECK: def non_optional(self):
361+
// CHECK: def non_optional(self) -> _ods_ir.Value:
362362
// CHECK: return self.operation.operands[0]
363363

364364
// CHECK: @builtins.property
365-
// CHECK: def optional(self):
365+
// CHECK: def optional(self) -> _Optional[_ods_ir.Value]:
366366
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
367367
}
368368

@@ -389,11 +389,11 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
389389
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
390390

391391
// CHECK: @builtins.property
392-
// CHECK: def non_variadic(self):
392+
// CHECK: def non_variadic(self) -> _ods_ir.Value:
393393
// CHECK: return self.operation.operands[0]
394394
//
395395
// CHECK: @builtins.property
396-
// CHECK: def variadic(self):
396+
// CHECK: def variadic(self) -> _ods_ir.OpOperandList:
397397
// CHECK: _ods_variadic_group_length = len(self.operation.operands) - 2 + 1
398398
// CHECK: return self.operation.operands[1:1 + _ods_variadic_group_length]
399399
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
@@ -422,12 +422,12 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
422422
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
423423

424424
// CHECK: @builtins.property
425-
// CHECK: def variadic(self):
425+
// CHECK: def variadic(self) -> _ods_ir.OpResultList:
426426
// CHECK: _ods_variadic_group_length = len(self.operation.results) - 2 + 1
427427
// CHECK: return self.operation.results[0:0 + _ods_variadic_group_length]
428428
//
429429
// CHECK: @builtins.property
430-
// CHECK: def non_variadic(self):
430+
// CHECK: def non_variadic(self) -> _ods_ir.OpResult:
431431
// CHECK: _ods_variadic_group_length = len(self.operation.results) - 2 + 1
432432
// CHECK: return self.operation.results[1 + _ods_variadic_group_length - 1]
433433
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
@@ -453,7 +453,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
453453
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
454454

455455
// CHECK: @builtins.property
456-
// CHECK: def in_(self):
456+
// CHECK: def in_(self) -> _ods_ir.Value:
457457
// CHECK: return self.operation.operands[0]
458458
let arguments = (ins AnyType:$in);
459459
}
@@ -491,17 +491,17 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
491491
def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
492492
[SameVariadicOperandSize]> {
493493
// CHECK: @builtins.property
494-
// CHECK: def variadic1(self):
494+
// CHECK: def variadic1(self) -> _ods_ir.OpOperandList:
495495
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 0)
496496
// CHECK: return self.operation.operands[start:start + elements_per_group]
497497
//
498498
// CHECK: @builtins.property
499-
// CHECK: def non_variadic(self):
499+
// CHECK: def non_variadic(self) -> _ods_ir.Value:
500500
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 1)
501501
// CHECK: return self.operation.operands[start]
502502
//
503503
// CHECK: @builtins.property
504-
// CHECK: def variadic2(self):
504+
// CHECK: def variadic2(self) -> _ods_ir.OpOperandList:
505505
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 1, 1)
506506
// CHECK: return self.operation.operands[start:start + elements_per_group]
507507
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -517,17 +517,17 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
517517
def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
518518
[SameVariadicResultSize]> {
519519
// CHECK: @builtins.property
520-
// CHECK: def variadic1(self):
520+
// CHECK: def variadic1(self) -> _ods_ir.OpResultList:
521521
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 0)
522522
// CHECK: return self.operation.results[start:start + elements_per_group]
523523
//
524524
// CHECK: @builtins.property
525-
// CHECK: def non_variadic(self):
525+
// CHECK: def non_variadic(self) -> _ods_ir.OpResult:
526526
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 1)
527527
// CHECK: return self.operation.results[start]
528528
//
529529
// CHECK: @builtins.property
530-
// CHECK: def variadic2(self):
530+
// CHECK: def variadic2(self) -> _ods_ir.OpResultList:
531531
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 1, 1)
532532
// CHECK: return self.operation.results[start:start + elements_per_group]
533533
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
@@ -557,20 +557,20 @@ def SimpleOp : TestOp<"simple"> {
557557
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
558558

559559
// CHECK: @builtins.property
560-
// CHECK: def i32(self):
560+
// CHECK: def i32(self) -> _ods_ir.Value:
561561
// CHECK: return self.operation.operands[0]
562562
//
563563
// CHECK: @builtins.property
564-
// CHECK: def f32(self):
564+
// CHECK: def f32(self) -> _ods_ir.Value:
565565
// CHECK: return self.operation.operands[1]
566566
let arguments = (ins I32:$i32, F32:$f32);
567567

568568
// CHECK: @builtins.property
569-
// CHECK: def i64(self):
569+
// CHECK: def i64(self) -> _ods_ir.OpResult:
570570
// CHECK: return self.operation.results[0]
571571
//
572572
// CHECK: @builtins.property
573-
// CHECK: def f64(self):
573+
// CHECK: def f64(self) -> _ods_ir.OpResult:
574574
// CHECK: return self.operation.results[1]
575575
let results = (outs I64:$i64, AnyFloat:$f64);
576576
}
@@ -595,11 +595,11 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
595595
let regions = (region AnyRegion:$region, AnyRegion, VariadicRegion<AnyRegion>:$variadic);
596596

597597
// CHECK: @builtins.property
598-
// CHECK: def region(self):
598+
// CHECK: def region(self) -> _ods_ir.Region:
599599
// CHECK: return self.regions[0]
600600

601601
// CHECK: @builtins.property
602-
// CHECK: def variadic(self):
602+
// CHECK: def variadic(self) -> _ods_ir.RegionSequence:
603603
// CHECK: return self.regions[2:]
604604
}
605605

@@ -623,7 +623,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
623623
let regions = (region VariadicRegion<AnyRegion>:$Variadic);
624624

625625
// CHECK: @builtins.property
626-
// CHECK: def Variadic(self):
626+
// CHECK: def Variadic(self) -> _ods_ir.RegionSequence:
627627
// CHECK: return self.regions[0:]
628628
}
629629

mlir/test/python/dialects/python_test.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ def resultTypesDefinedByTraits():
323323
# CHECK: f32 index
324324
print(no_infer.single.type, no_infer.doubled.type)
325325

326+
326327
# CHECK-LABEL: TEST: testOptionalOperandOp
327328
@run
328329
def testOptionalOperandOp():
@@ -615,11 +616,17 @@ def values(lst):
615616
[zero, one], two, [three, four]
616617
)
617618
# CHECK: Value(%{{.*}} = arith.constant 2 : i32)
618-
print(variadic_operands.non_variadic)
619+
non_variadic = variadic_operands.non_variadic
620+
print(non_variadic)
621+
assert isinstance(non_variadic, Value)
619622
# CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
620-
print(values(variadic_operands.variadic1))
623+
variadic1 = variadic_operands.variadic1
624+
print(values(variadic1))
625+
assert isinstance(variadic1, OpOperandList)
621626
# CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
622-
print(values(variadic_operands.variadic2))
627+
variadic2 = variadic_operands.variadic2
628+
print(values(variadic2))
629+
assert isinstance(variadic2, OpOperandList)
623630

624631

625632
# CHECK-LABEL: TEST: testVariadicResultAccess
@@ -660,7 +667,9 @@ def types(lst):
660667
# CHECK: i1
661668
print(op.non_variadic2.type)
662669
# CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
663-
print(types(op.variadic))
670+
variadic = op.variadic
671+
print(types(variadic))
672+
assert isinstance(variadic, OpResultList)
664673

665674
# Test Variadic-Variadic-Fixed
666675
op = test.SameVariadicResultSizeOpVVF(
@@ -713,3 +722,14 @@ def types(lst):
713722
print(types(op.variadic2))
714723
# CHECK: i4
715724
print(op.non_variadic3.type)
725+
726+
727+
# CHECK-LABEL: TEST: testVariadicAndNormalRegion
728+
@run
729+
def testVariadicAndNormalRegionOp():
730+
with Context() as ctx, Location.unknown(ctx):
731+
module = Module.create()
732+
with InsertionPoint(module.body):
733+
region_op = test.VariadicAndNormalRegionOp(2)
734+
assert isinstance(region_op.region, Region)
735+
assert isinstance(region_op.variadic, RegionSequence)

mlir/test/python/python_test_ops.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,4 +265,9 @@ def SameVariadicResultSizeOpFVFVF : TestOp<"same_variadic_result_fvfvf",
265265
AnyType:$non_variadic3);
266266
}
267267

268+
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
269+
let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$variadic);
270+
let assemblyFormat = "$region $variadic attr-dict";
271+
}
272+
268273
#endif // PYTHON_TEST_OPS

0 commit comments

Comments
 (0)