diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index de414dc52c0a0..b3dd79c7dbd79 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -7,6 +7,7 @@ from .._transform_ops_gen import _Dialect from ..._mlir_libs._mlirDialectsTransform import * from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType +from . import interpreter try: from ...ir import * @@ -324,6 +325,25 @@ def bodyTarget(self) -> Value: def bodyExtraArgs(self) -> BlockArgumentList: return self.body.arguments[1:] + def apply( + self, + payload: Module, + transform_options: Optional[interpreter.TransformOptions] = None, + ) -> Module: + assert self.parent + assert "transform.with_named_sequence" in self.parent.attributes + assert isinstance( + self.parent.attributes["transform.with_named_sequence"], UnitAttr + ) + + interpreter.apply_named_sequence( + payload_root=payload, + transform_root=self, + transform_module=self.parent, + transform_options=transform_options, + ) + return payload # NB: was modified in-place (if any transformation happened) + def named_sequence( sym_name: Union[str, SymbolRefAttr], diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py index 819a3be1db9d5..ca9ce5dbd23c1 100644 --- a/mlir/test/python/dialects/transform_interpreter.py +++ b/mlir/test/python/dialects/transform_interpreter.py @@ -31,6 +31,20 @@ def print_self(): # CHECK: transform.yield +@test_in_context +def print_self_via_apply_method(): + m = ir.Module.parse( + print_root_module.replace("from interpreter", "print_self_via_apply_method") + ) + m.body.operations[0].apply(m) + + +# CHECK-LABEL: print_self_via_apply_method +# CHECK: transform.named_sequence @__transform_main +# CHECK: transform.print +# CHECK: transform.yield + + @test_in_context def print_other(): transform = ir.Module.parse(