From 972ac9725e647e6702c4d2a8b04f972aca0c67f6 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Thu, 6 Nov 2025 15:43:10 -0800 Subject: [PATCH 1/2] [MLIR][Transform][Python] Sync derived classes and their wrappers Updates the derived Op-classes for the main transform ops to have all the arguments, etc, from the auto-generated classes. Additionally updates and adds missing snake_case wrappers for the derived classes which shadow the snake_case wrappers which were exposed alongside the derived classes. --- .../mlir/dialects/transform/__init__.py | 163 +++++++++++++++++- mlir/test/python/dialects/transform.py | 69 +++++++- 2 files changed, 219 insertions(+), 13 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index b075919d1ef0f..de414dc52c0a0 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -39,16 +39,32 @@ def __init__( super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) +def cast( + result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None +) -> OpResult: + return CastOp(result_type=result_type, target=target, loc=loc, ip=ip).result + + @_ods_cext.register_operation(_Dialect, replace=True) class ApplyPatternsOp(ApplyPatternsOp): def __init__( self, target: Union[Operation, Value, OpView], + apply_cse: bool = False, + max_iterations: Optional[Union[IntegerAttr, int]] = None, + max_num_rewrites: Optional[Union[IntegerAttr, int]] = None, *, loc=None, ip=None, ): - super().__init__(target, loc=loc, ip=ip) + super().__init__( + target, + apply_cse=apply_cse, + max_iterations=max_iterations, + max_num_rewrites=max_num_rewrites, + loc=loc, + ip=ip, + ) self.regions[0].blocks.append() @property @@ -56,6 +72,25 @@ def patterns(self) -> Block: return self.regions[0].blocks[0] +def apply_patterns( + target: Union[Operation, Value, OpView], + apply_cse: bool = False, + max_iterations: Optional[Union[IntegerAttr, int]] = None, + max_num_rewrites: Optional[Union[IntegerAttr, int]] = None, + *, + loc=None, + ip=None, +) -> ApplyPatternsOp: + return ApplyPatternsOp( + target=target, + apply_cse=apply_cse, + max_iterations=max_iterations, + max_num_rewrites=max_num_rewrites, + loc=loc, + ip=ip, + ) + + @_ods_cext.register_operation(_Dialect, replace=True) class GetParentOp(GetParentOp): def __init__( @@ -64,6 +99,7 @@ def __init__( target: Union[Operation, Value], *, isolated_from_above: bool = False, + allow_empty_results: bool = False, op_name: Optional[str] = None, deduplicate: bool = False, nth_parent: int = 1, @@ -74,6 +110,7 @@ def __init__( result_type, _get_op_result_or_value(target), isolated_from_above=isolated_from_above, + allow_empty_results=allow_empty_results, op_name=op_name, deduplicate=deduplicate, nth_parent=nth_parent, @@ -82,6 +119,31 @@ def __init__( ) +def get_parent_op( + result_type: Type, + target: Union[Operation, Value], + *, + isolated_from_above: bool = False, + allow_empty_results: bool = False, + op_name: Optional[str] = None, + deduplicate: bool = False, + nth_parent: int = 1, + loc=None, + ip=None, +) -> OpResult: + return GetParentOp( + result_type=result_type, + target=target, + isolated_from_above=isolated_from_above, + allow_empty_results=allow_empty_results, + op_name=op_name, + deduplicate=deduplicate, + nth_parent=nth_parent, + loc=loc, + ip=ip, + ).result + + @_ods_cext.register_operation(_Dialect, replace=True) class MergeHandlesOp(MergeHandlesOp): def __init__( @@ -89,17 +151,32 @@ def __init__( handles: Sequence[Union[Operation, Value]], *, deduplicate: bool = False, + results: Optional[Sequence[Type]] = None, loc=None, ip=None, ): super().__init__( [_get_op_result_or_value(h) for h in handles], deduplicate=deduplicate, + results=results, loc=loc, ip=ip, ) +def merge_handles( + handles: Sequence[Union[Operation, Value]], + *, + deduplicate: bool = False, + results: Optional[Sequence[Type]] = None, + loc=None, + ip=None, +) -> OpResult: + return MergeHandlesOp( + handles=handles, deduplicate=deduplicate, results=results, loc=loc, ip=ip + ).result + + @_ods_cext.register_operation(_Dialect, replace=True) class ReplicateOp(ReplicateOp): def __init__( @@ -119,16 +196,31 @@ def __init__( ) +def replicate( + pattern: Union[Operation, Value], + handles: Sequence[Union[Operation, Value]], + *, + loc=None, + ip=None, +) -> Union[OpResult, OpResultList, ReplicateOp]: + op = ReplicateOp(pattern=pattern, handles=handles, loc=loc, ip=ip) + results = op.results + return results if len(results) > 1 else (results[0] if len(results) == 1 else op) + + @_ods_cext.register_operation(_Dialect, replace=True) class SequenceOp(SequenceOp): def __init__( self, - failure_propagation_mode, + failure_propagation_mode: FailurePropagationMode, results: Sequence[Type], target: Union[Operation, Value, Type], extra_bindings: Optional[ Union[Sequence[Value], Sequence[Type], Operation, OpView] ] = None, + *, + loc=None, + ip=None, ): root = ( _get_op_result_or_value(target) @@ -155,6 +247,8 @@ def __init__( failure_propagation_mode=failure_propagation_mode, root=root, extra_bindings=extra_bindings, + loc=loc, + ip=ip, ) self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) @@ -171,16 +265,42 @@ def bodyExtraArgs(self) -> BlockArgumentList: return self.body.arguments[1:] +def sequence( + failure_propagation_mode: FailurePropagationMode, + results: Sequence[Type], + target: Union[Operation, Value, Type], + extra_bindings: Optional[ + Union[Sequence[Value], Sequence[Type], Operation, OpView] + ] = None, + *, + loc=None, + ip=None, +) -> Union[OpResult, OpResultList, SequenceOp]: + op = SequenceOp( + results=results, + failure_propagation_mode=failure_propagation_mode, + extra_bindings=extra_bindings, + target=target, + loc=loc, + ip=ip, + ) + results = op.results + return results if len(results) > 1 else (results[0] if len(results) == 1 else op) + + @_ods_cext.register_operation(_Dialect, replace=True) class NamedSequenceOp(NamedSequenceOp): def __init__( self, - sym_name, + sym_name: Union[str, SymbolRefAttr], input_types: Sequence[Type], result_types: Sequence[Type], - sym_visibility=None, - arg_attrs=None, - res_attrs=None, + *, + sym_visibility: Optional[Union[str, StringAttr]] = None, + arg_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None, + res_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None, + loc=None, + ip=None, ): function_type = FunctionType.get(input_types, result_types) super().__init__( @@ -205,6 +325,29 @@ def bodyExtraArgs(self) -> BlockArgumentList: return self.body.arguments[1:] +def named_sequence( + sym_name: Union[str, SymbolRefAttr], + input_types: Sequence[Type], + result_types: Sequence[Type], + *, + sym_visibility: Optional[Union[str, StringAttr]] = None, + arg_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None, + res_attrs: Optional[Union[Sequence[dict], "DictArrayAttr"]] = None, + loc=None, + ip=None, +) -> NamedSequenceOp: + return NamedSequenceOp( + sym_name=sym_name, + input_types=input_types, + result_types=result_types, + sym_visibility=sym_visibility, + arg_attrs=arg_attrs, + res_attrs=res_attrs, + loc=loc, + ip=ip, + ) + + @_ods_cext.register_operation(_Dialect, replace=True) class YieldOp(YieldOp): def __init__( @@ -219,6 +362,12 @@ def __init__( super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) +def yield_( + operands: Optional[Union[Operation, Sequence[Value]]] = None, *, loc=None, ip=None +) -> YieldOp: + return YieldOp(operands=operands, loc=loc, ip=ip) + + OptionValueTypes = Union[ Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool ] @@ -247,7 +396,7 @@ def __init__( def option_value_to_attr(value): nonlocal cur_param_operand_idx if isinstance(value, (Value, Operation, OpView)): - dynamic_options.append(_get_op_result_or_value(value)) + dynamic_options.append(value) cur_param_operand_idx += 1 return ParamOperandAttr(cur_param_operand_idx - 1, context) elif isinstance(value, Attribute): diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index 6c5e4e5505b1c..97850a851f55b 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -43,6 +43,26 @@ def testTypes(module: Module): print(param.type) +@run +def testSequenceOp(module: Module): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [transform.AnyOpType.get()], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): + res = transform.CastOp(transform.AnyOpType.get(), sequence.bodyTarget) + res2 = transform.cast(transform.any_op_t(), res.result) + transform.YieldOp([res2]) + # CHECK-LABEL: TEST: testSequenceOp + # CHECK: transform.sequence + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: %[[RES:.+]] = cast %[[ARG0]] : !transform.any_op to !transform.any_op + # CHECK: %[[RES2:.+]] = cast %[[RES]] : !transform.any_op to !transform.any_op + # CHECK: yield %[[RES2]] : !transform.any_op + # CHECK: } + + @run def testSequenceOp(module: Module): sequence = transform.SequenceOp( @@ -103,7 +123,16 @@ def testSequenceOpWithExtras(module: Module): # CHECK-LABEL: TEST: testSequenceOpWithExtras # CHECK: transform.sequence failures(propagate) # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): - + sequence = transform.sequence( + transform.FailurePropagationMode.Propagate, + [], + transform.AnyOpType.get(), + [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], + ) + with InsertionPoint(sequence.body): + transform.yield_() + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): @run def testNestedSequenceOpWithExtras(module: Module): @@ -166,8 +195,17 @@ def testNamedSequenceOp(module: Module): transform.YieldOp([named_sequence.bodyTarget]) # CHECK-LABEL: TEST: testNamedSequenceOp # CHECK: module attributes {transform.with_named_sequence} { - # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op { - # CHECK: yield %[[ARG0]] : !transform.any_op + # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op { + # CHECK: yield %[[ARG0]] : !transform.any_op + named_sequence = transform.named_sequence( + "other_seq", + [transform.AnyOpType.get()], + [transform.AnyOpType.get()], + arg_attrs = [{"transform.consumed": UnitAttr.get()}]) + with InsertionPoint(named_sequence.body): + transform.yield_([named_sequence.bodyTarget]) + # CHECK: transform.named_sequence @other_seq(%[[ARG1:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op { + # CHECK: yield %[[ARG1]] : !transform.any_op @run @@ -182,11 +220,21 @@ def testGetParentOp(module: Module): isolated_from_above=True, nth_parent=2, ) + transform.get_parent_op( + transform.AnyOpType.get(), + sequence.bodyTarget, + isolated_from_above=True, + nth_parent=2, + allow_empty_results=True, + op_name="func.func", + deduplicate=True, + ) transform.YieldOp() # CHECK-LABEL: TEST: testGetParentOp # CHECK: transform.sequence # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64} + # CHECK: = get_parent_op %[[ARG1]] {allow_empty_results, deduplicate, isolated_from_above, nth_parent = 2 : i64, op_name = "func.func"} @run @@ -195,12 +243,14 @@ def testMergeHandlesOp(module: Module): transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): - transform.MergeHandlesOp([sequence.bodyTarget]) + res = transform.MergeHandlesOp([sequence.bodyTarget]) + transform.merge_handles([res.result], deduplicate=True) transform.YieldOp() # CHECK-LABEL: TEST: testMergeHandlesOp # CHECK: transform.sequence # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: = merge_handles %[[ARG1]] + # CHECK: %[[RES1:.+]] = merge_handles %[[ARG1]] : !transform.any_op + # CHECK: = merge_handles deduplicate %[[RES1]] : !transform.any_op @run @@ -211,11 +261,16 @@ def testApplyPatternsOpCompact(module: Module): with InsertionPoint(sequence.body): with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): transform.ApplyCanonicalizationPatternsOp() + with InsertionPoint(transform.apply_patterns(sequence.bodyTarget, apply_cse=True, max_iterations=3, max_num_rewrites=5).patterns): + transform.ApplyCanonicalizationPatternsOp() transform.YieldOp() # CHECK-LABEL: TEST: testApplyPatternsOpCompact # CHECK: apply_patterns to # CHECK: transform.apply_patterns.canonicalization - # CHECK: !transform.any_op + # CHECK: } : !transform.any_op + # CHECK: apply_patterns to + # CHECK: transform.apply_patterns.canonicalization + # CHECK: } {apply_cse, max_iterations = 3 : i64, max_num_rewrites = 5 : i64} : !transform.any_op @run @@ -249,11 +304,13 @@ def testReplicateOp(module: Module): transform.AnyOpType.get(), sequence.bodyTarget, "second" ) transform.ReplicateOp(m1, [m2]) + transform.replicate(m1, [m2]) transform.YieldOp() # CHECK-LABEL: TEST: testReplicateOp # CHECK: %[[FIRST:.+]] = pdl_match # CHECK: %[[SECOND:.+]] = pdl_match # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]] + # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]] # CHECK-LABEL: TEST: testApplyRegisteredPassOp From 3b1ad3bafa910b5fe46c64355e93b60c31f44f44 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Thu, 6 Nov 2025 16:01:23 -0800 Subject: [PATCH 2/2] Formatting --- mlir/test/python/dialects/transform.py | 182 +++++++++++++------------ 1 file changed, 97 insertions(+), 85 deletions(-) diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index 97850a851f55b..f58442d04fc66 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -78,6 +78,7 @@ def testSequenceOp(module: Module): # CHECK: yield %[[ARG0]] : !transform.any_op # CHECK: } + @run def testNestedSequenceOp(module: Module): sequence = transform.SequenceOp( @@ -134,53 +135,54 @@ def testSequenceOpWithExtras(module: Module): # CHECK: transform.sequence failures(propagate) # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): + @run def testNestedSequenceOpWithExtras(module: Module): - sequence = transform.SequenceOp( + sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], ) - with InsertionPoint(sequence.body): - nested = transform.SequenceOp( + with InsertionPoint(sequence.body): + nested = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget, sequence.bodyExtraArgs, ) - with InsertionPoint(nested.body): - transform.YieldOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras - # CHECK: transform.sequence failures(propagate) - # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): - # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) + with InsertionPoint(nested.body): + transform.YieldOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): + # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) @run def testTransformPDLOps(module: Module): - withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) - with InsertionPoint(withPdl.body): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, - [transform.AnyOpType.get()], - withPdl.bodyTarget, - ) - with InsertionPoint(sequence.body): - match = transform_pdl.PDLMatchOp( - transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher" - ) - transform.YieldOp(match) - # CHECK-LABEL: TEST: testTransformPDLOps - # CHECK: transform.with_pdl_patterns { - # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): - # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] - # CHECK: yield %[[RES]] : !transform.any_op - # CHECK: } - # CHECK: } + withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) + with InsertionPoint(withPdl.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [transform.AnyOpType.get()], + withPdl.bodyTarget, + ) + with InsertionPoint(sequence.body): + match = transform_pdl.PDLMatchOp( + transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher" + ) + transform.YieldOp(match) + # CHECK-LABEL: TEST: testTransformPDLOps + # CHECK: transform.with_pdl_patterns { + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] + # CHECK: yield %[[RES]] : !transform.any_op + # CHECK: } + # CHECK: } @run @@ -190,7 +192,8 @@ def testNamedSequenceOp(module: Module): "__transform_main", [transform.AnyOpType.get()], [transform.AnyOpType.get()], - arg_attrs = [{"transform.consumed": UnitAttr.get()}]) + arg_attrs=[{"transform.consumed": UnitAttr.get()}], + ) with InsertionPoint(named_sequence.body): transform.YieldOp([named_sequence.bodyTarget]) # CHECK-LABEL: TEST: testNamedSequenceOp @@ -201,7 +204,8 @@ def testNamedSequenceOp(module: Module): "other_seq", [transform.AnyOpType.get()], [transform.AnyOpType.get()], - arg_attrs = [{"transform.consumed": UnitAttr.get()}]) + arg_attrs=[{"transform.consumed": UnitAttr.get()}], + ) with InsertionPoint(named_sequence.body): transform.yield_([named_sequence.bodyTarget]) # CHECK: transform.named_sequence @other_seq(%[[ARG1:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op { @@ -210,31 +214,31 @@ def testNamedSequenceOp(module: Module): @run def testGetParentOp(module: Module): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - transform.GetParentOp( - transform.AnyOpType.get(), - sequence.bodyTarget, - isolated_from_above=True, - nth_parent=2, - ) - transform.get_parent_op( - transform.AnyOpType.get(), - sequence.bodyTarget, - isolated_from_above=True, - nth_parent=2, - allow_empty_results=True, - op_name="func.func", - deduplicate=True, + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() ) - transform.YieldOp() - # CHECK-LABEL: TEST: testGetParentOp - # CHECK: transform.sequence - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64} - # CHECK: = get_parent_op %[[ARG1]] {allow_empty_results, deduplicate, isolated_from_above, nth_parent = 2 : i64, op_name = "func.func"} + with InsertionPoint(sequence.body): + transform.GetParentOp( + transform.AnyOpType.get(), + sequence.bodyTarget, + isolated_from_above=True, + nth_parent=2, + ) + transform.get_parent_op( + transform.AnyOpType.get(), + sequence.bodyTarget, + isolated_from_above=True, + nth_parent=2, + allow_empty_results=True, + op_name="func.func", + deduplicate=True, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testGetParentOp + # CHECK: transform.sequence + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64} + # CHECK: = get_parent_op %[[ARG1]] {allow_empty_results, deduplicate, isolated_from_above, nth_parent = 2 : i64, op_name = "func.func"} @run @@ -255,38 +259,46 @@ def testMergeHandlesOp(module: Module): @run def testApplyPatternsOpCompact(module: Module): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() - ) - with InsertionPoint(sequence.body): - with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): - transform.ApplyCanonicalizationPatternsOp() - with InsertionPoint(transform.apply_patterns(sequence.bodyTarget, apply_cse=True, max_iterations=3, max_num_rewrites=5).patterns): - transform.ApplyCanonicalizationPatternsOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testApplyPatternsOpCompact - # CHECK: apply_patterns to - # CHECK: transform.apply_patterns.canonicalization - # CHECK: } : !transform.any_op - # CHECK: apply_patterns to - # CHECK: transform.apply_patterns.canonicalization - # CHECK: } {apply_cse, max_iterations = 3 : i64, max_num_rewrites = 5 : i64} : !transform.any_op + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): + transform.ApplyCanonicalizationPatternsOp() + with InsertionPoint( + transform.apply_patterns( + sequence.bodyTarget, + apply_cse=True, + max_iterations=3, + max_num_rewrites=5, + ).patterns + ): + transform.ApplyCanonicalizationPatternsOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testApplyPatternsOpCompact + # CHECK: apply_patterns to + # CHECK: transform.apply_patterns.canonicalization + # CHECK: } : !transform.any_op + # CHECK: apply_patterns to + # CHECK: transform.apply_patterns.canonicalization + # CHECK: } {apply_cse, max_iterations = 3 : i64, max_num_rewrites = 5 : i64} : !transform.any_op @run def testApplyPatternsOpWithType(module: Module): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], - transform.OperationType.get('test.dummy') - ) - with InsertionPoint(sequence.body): - with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): - transform.ApplyCanonicalizationPatternsOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testApplyPatternsOp - # CHECK: apply_patterns to - # CHECK: transform.apply_patterns.canonicalization - # CHECK: !transform.op<"test.dummy"> + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("test.dummy"), + ) + with InsertionPoint(sequence.body): + with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns): + transform.ApplyCanonicalizationPatternsOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testApplyPatternsOp + # CHECK: apply_patterns to + # CHECK: transform.apply_patterns.canonicalization + # CHECK: !transform.op<"test.dummy"> @run