diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h index 8f7085f6024f5..8a7f30fd218dc 100644 --- a/mlir/include/mlir/Bindings/Python/Globals.h +++ b/mlir/include/mlir/Bindings/Python/Globals.h @@ -78,10 +78,10 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals { bool replace = false); /// Adds a concrete implementation dialect class. - /// Raises an exception if the mapping already exists. + /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerDialectImpl(const std::string &dialectNamespace, - nanobind::object pyClass); + nanobind::object pyClass, bool replace = false); /// Adds a concrete implementation operation class. /// Raises an exception if the mapping already exists and replace == false. diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp index 411b8a6705f1c..82195acb9f4fb 100644 --- a/mlir/lib/Bindings/Python/Globals.cpp +++ b/mlir/lib/Bindings/Python/Globals.cpp @@ -130,10 +130,10 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, } void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, - nb::object pyClass) { + nb::object pyClass, bool replace) { nb::ft_lock_guard lock(mutex); nb::object &found = dialectClassMap[dialectNamespace]; - if (found) { + if (found && !replace) { throw std::runtime_error(nanobind::detail::join( "Dialect namespace '", dialectNamespace, "' is already registered.")); } diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b8637c57a3f48..7341e7218c962 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2791,7 +2791,8 @@ void populateRoot(nb::module_ &m) { }, "dialect_namespace"_a) .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, - "dialect_namespace"_a, "dialect_class"_a, + "dialect_namespace"_a, "dialect_class"_a, nb::kw_only(), + "replace"_a = false, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, "operation_name"_a, "operation_class"_a, nb::kw_only(), diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py index 5bcc595220f69..45eb218af3448 100644 --- a/mlir/python/mlir/dialects/ext.py +++ b/mlir/python/mlir/dialects/ext.py @@ -33,29 +33,12 @@ "Region", "Type", "Attribute", - "register_dialect", - "register_operation", ] Operand = ir.Value Result = ir.OpResult Region = ir.Region -register_dialect = _cext.register_dialect - - -def register_operation( - dialect_cls: type, *, replace: bool = False -) -> Callable[[type], type]: - register = _cext.register_operation(dialect_cls, replace=replace) - - def decorator(op_cls: type) -> type: - register(op_cls) - _cext.register_op_adaptor(op_cls, replace=replace)(op_cls.Adaptor) - return op_cls - - return decorator - def construct_instance(origin, args): # `origin.get` is to construct an instance of MLIR type or attribute. @@ -814,11 +797,14 @@ def _emit_module(cls) -> ir.Module: def load( cls, *, - register: bool = True, reload: bool = False, - replace: bool = False, ) -> None: if hasattr(cls, "_mlir_module") and not reload: + if cls._mlir_module.context is not ir.Context.current: + raise RuntimeError( + "This dialect was loaded in a different context. " + "Please set reload=True to reload the dialect in the current context." + ) return cls._mlir_module = cls._emit_module() @@ -831,17 +817,16 @@ def load( for op in cls.operations: op._attach_traits() + _cext.globals._register_dialect_impl(cls.DIALECT_NAMESPACE, cls, replace=reload) + for type_ in cls.types: typeid = ir.DynamicType.lookup_typeid(type_.type_name) - _cext.register_type_caster(typeid, replace=replace)(type_) + _cext.register_type_caster(typeid, replace=reload)(type_) for attr in cls.attributes: typeid = ir.DynamicAttr.lookup_typeid(attr.attr_name) - _cext.register_type_caster(typeid, replace=replace)(attr) - - if register: - register_dialect(cls) + _cext.register_type_caster(typeid, replace=reload)(attr) - register_dialect_operation = register_operation(cls, replace=replace) - for op in cls.operations: - register_dialect_operation(op) + for op in cls.operations: + _cext.register_operation(cls, replace=reload)(op) + _cext.register_op_adaptor(op, replace=reload)(op.Adaptor) diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py index f58e0be13befd..a6e2c6da45322 100644 --- a/mlir/test/python/dialects/transform_op_interface.py +++ b/mlir/test/python/dialects/transform_op_interface.py @@ -16,7 +16,6 @@ ) -@ext.register_dialect class MyTransform(ext.Dialect, name="my_transform"): pass @@ -26,7 +25,7 @@ def run(emit_schedule): with ir.Context() as ctx, ir.Location.unknown(): payload = emit_payload() - MyTransform.load(register=False, reload=True) + MyTransform.load(reload=True) GetNamedAttributeOp.attach_interface_impls(ctx) PrintParamOp.attach_interface_impls(ctx) @@ -86,7 +85,6 @@ def get_effects(op: ir.Operation, effects): # Demonstration of a TransformOpInterface-implementing op that gets named attributes # from target ops and produces them as param handles. -@ext.register_operation(MyTransform) class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"): target: ext.Operand[transform.AnyOpType] attr_name: ir.StringAttr @@ -120,7 +118,6 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool: return False -@ext.register_operation(MyTransform) class PrintParamOp(MyTransform.Operation, name="print_param"): target: ext.Operand[transform.AnyParamType] name: ir.StringAttr @@ -150,7 +147,6 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool: # Syntax for an op with one op handle operand and one op handle result. -@ext.register_operation(MyTransform) class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"): target: ext.Operand[transform.AnyOpType] res: ext.Result[transform.AnyOpType[()]] @@ -273,7 +269,6 @@ def get_effects(op: ir.Operation, effects): return schedule -@ext.register_operation(MyTransform) class OpValParamInParamOpValOut( MyTransform.Operation, name="op_val_param_in_param_op_val_out" ): @@ -378,7 +373,6 @@ def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool: return schedule -@ext.register_operation(MyTransform) class OpsParamsInValuesParamOut( MyTransform.Operation, name="ops_params_in_values_param_out" ): diff --git a/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py index 470c679179b03..9cd73331cfdea 100644 --- a/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py +++ b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py @@ -7,7 +7,6 @@ from mlir.dialects.transform import AnyOpType, structured -@ext.register_dialect class MyPatternDescriptors(ext.Dialect, name="my_pattern_descriptors"): pass @@ -17,7 +16,7 @@ def run(emit_schedule): with ir.Context(), ir.Location.unknown(): payload = emit_payload() - MyPatternDescriptors.load(register=False, reload=True) + MyPatternDescriptors.load(reload=True) # NB: Pattern descriptor ops have their interfaces attached # in their respective test functions. @@ -58,7 +57,6 @@ def schedule_boilerplate(): yield schedule, named_sequence -@ext.register_operation(MyPatternDescriptors) class SubiAddiRewritePatternOp(MyPatternDescriptors.Operation, name="add_pattern"): @classmethod def attach_interface_impls(cls, ctx=None):