diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py index c2bd33b4203fe..5bcc595220f69 100644 --- a/mlir/python/mlir/dialects/ext.py +++ b/mlir/python/mlir/dialects/ext.py @@ -44,12 +44,14 @@ register_dialect = _cext.register_dialect -def register_operation(dialect_cls: type) -> Callable[[type], type]: - register = _cext.register_operation(dialect_cls) +def register_operation( + dialect_cls: type, *, replace: bool = False +) -> Callable[[type], type]: + register = _cext.register_operation(dialect_cls, replace=replace) def decorator(op_cls: type) -> type: register(op_cls) - _cext.register_op_adaptor(op_cls)(op_cls.Adaptor) + _cext.register_op_adaptor(op_cls, replace=replace)(op_cls.Adaptor) return op_cls return decorator @@ -809,7 +811,13 @@ def _emit_module(cls) -> ir.Module: return m @classmethod - def load(cls, register=True, reload=False) -> None: + def load( + cls, + *, + register: bool = True, + reload: bool = False, + replace: bool = False, + ) -> None: if hasattr(cls, "_mlir_module") and not reload: return @@ -825,15 +833,15 @@ def load(cls, register=True, reload=False) -> None: for type_ in cls.types: typeid = ir.DynamicType.lookup_typeid(type_.type_name) - _cext.register_type_caster(typeid)(type_) + _cext.register_type_caster(typeid, replace=replace)(type_) for attr in cls.attributes: typeid = ir.DynamicAttr.lookup_typeid(attr.attr_name) - _cext.register_type_caster(typeid)(attr) + _cext.register_type_caster(typeid, replace=replace)(attr) if register: register_dialect(cls) - register_dialect_operation = register_operation(cls) + register_dialect_operation = register_operation(cls, replace=replace) for op in cls.operations: register_dialect_operation(op)