Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mlir/python/mlir/dialects/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
from . import interpreter

try:
from ...ir import *
Expand Down Expand Up @@ -324,6 +325,25 @@ def bodyTarget(self) -> Value:
def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]

def apply(
self,
payload: Module,
transform_options: Optional[interpreter.TransformOptions] = None,
) -> Module:
assert self.parent
assert "transform.with_named_sequence" in self.parent.attributes
assert isinstance(
self.parent.attributes["transform.with_named_sequence"], UnitAttr
)

interpreter.apply_named_sequence(
payload_root=payload,
transform_root=self,
transform_module=self.parent,
transform_options=transform_options,
)
return payload # NB: was modified in-place (if any transformation happened)


def named_sequence(
sym_name: Union[str, SymbolRefAttr],
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/python/dialects/transform_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ def print_self():
# CHECK: transform.yield


@test_in_context
def print_self_via_apply_method():
m = ir.Module.parse(
print_root_module.replace("from interpreter", "print_self_via_apply_method")
)
m.body.operations[0].apply(m)


# CHECK-LABEL: print_self_via_apply_method
# CHECK: transform.named_sequence @__transform_main
# CHECK: transform.print
# CHECK: transform.yield


@test_in_context
def print_other():
transform = ir.Module.parse(
Expand Down
Loading