diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 6ce77b4cb93f6..32f46d24cc739 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -56,6 +56,21 @@ def get_include_dirs() -> Sequence[str]: # # This facility allows downstreams to customize Context creation to their # needs. + +_dialect_registry = None + + +def get_dialect_registry(): + global _dialect_registry + + if _dialect_registry is None: + from ._mlir import ir + + _dialect_registry = ir.DialectRegistry() + + return _dialect_registry + + def _site_initialize(): import importlib import itertools @@ -63,7 +78,6 @@ def _site_initialize(): from ._mlir import ir logger = logging.getLogger(__name__) - registry = ir.DialectRegistry() post_init_hooks = [] disable_multithreading = False @@ -84,7 +98,7 @@ def process_initializer_module(module_name): logger.debug("Initializing MLIR with module: %s", module_name) if hasattr(m, "register_dialects"): logger.debug("Registering dialects from initializer %r", m) - m.register_dialects(registry) + m.register_dialects(get_dialect_registry()) if hasattr(m, "context_init_hook"): logger.debug("Adding context init hook from %r", m) post_init_hooks.append(m.context_init_hook) @@ -110,7 +124,7 @@ def process_initializer_module(module_name): class Context(ir._BaseContext): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.append_dialect_registry(registry) + self.append_dialect_registry(get_dialect_registry()) for hook in post_init_hooks: hook(self) if not disable_multithreading: diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 6579e02d8549e..b5baa80bc767f 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -11,7 +11,7 @@ ) -def register_python_test_dialect(context, load=True): +def register_python_test_dialect(registry): from .._mlir_libs import _mlirPythonTest - _mlirPythonTest.register_python_test_dialect(context, load) + _mlirPythonTest.register_dialect(registry) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 18526ab8c3c02..6d21da3b4179f 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -5,6 +5,7 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug from ._mlir_libs._mlir import register_type_caster, register_value_caster +from ._mlir_libs import get_dialect_registry # Convenience decorator for registering user-friendly Attribute builders. diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index f313a400b73c0..88761c9d08fe0 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -6,6 +6,8 @@ import mlir.dialects.tensor as tensor import mlir.dialects.arith as arith +test.register_python_test_dialect(get_dialect_registry()) + def run(f): print("\nTEST:", f.__name__) @@ -17,7 +19,6 @@ def run(f): @run def testAttributes(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) # # Check op construction with attributes. # @@ -138,7 +139,6 @@ def testAttributes(): @run def attrBuilder(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) # CHECK: python_test.attributes_op op = test.AttributesOp( # CHECK-DAG: x_affinemap = affine_map<() -> (2)> @@ -215,7 +215,6 @@ def attrBuilder(): @run def inferReturnTypes(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): op = test.InferResultsOp() @@ -260,7 +259,6 @@ def inferReturnTypes(): @run def resultTypesDefinedByTraits(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): inferred = test.InferResultsOp() @@ -295,8 +293,6 @@ def resultTypesDefinedByTraits(): @run def testOptionalOperandOp(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) - module = Module.create() with InsertionPoint(module.body): op1 = test.OptionalOperandOp() @@ -312,7 +308,6 @@ def testOptionalOperandOp(): @run def testCustomAttribute(): with Context() as ctx: - test.register_python_test_dialect(ctx) a = test.TestAttr.get() # CHECK: #python_test.test_attr print(a) @@ -350,7 +345,6 @@ def testCustomAttribute(): @run def testCustomType(): with Context() as ctx: - test.register_python_test_dialect(ctx) a = test.TestType.get() # CHECK: !python_test.test_type print(a) @@ -397,8 +391,6 @@ def testCustomType(): # CHECK-LABEL: TEST: testTensorValue def testTensorValue(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) - i8 = IntegerType.get_signless(8) class Tensor(test.TestTensorValue): @@ -436,7 +428,6 @@ def __str__(self): @run def inferReturnTypeComponents(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): @@ -488,8 +479,6 @@ def inferReturnTypeComponents(): @run def testCustomTypeTypeCaster(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) - a = test.TestType.get() assert a.typeid is not None @@ -542,7 +531,6 @@ def type_caster(pytype): @run def testInferTypeOpInterface(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): i64 = IntegerType.get_signless(64) diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp index aff414894cb82..f81b851f8759b 100644 --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) { }, py::arg("context"), py::arg("load") = true); + m.def( + "register_dialect", + [](MlirDialectRegistry registry) { + MlirDialectHandle pythonTestDialect = + mlirGetDialectHandle__python_test__(); + mlirDialectHandleInsertDialect(pythonTestDialect, registry); + }, + py::arg("registry")); + mlir_attribute_subclass(m, "TestAttr", mlirAttributeIsAPythonTestTestAttribute) .def_classmethod(