diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py index 5bcc595220f69..15651a1c4e858 100644 --- a/mlir/python/mlir/dialects/ext.py +++ b/mlir/python/mlir/dialects/ext.py @@ -22,6 +22,7 @@ from ._ods_common import _cext, segmented_accessor from .irdl import Variadicity from ..passmanager import PassManager +from contextlib import nullcontext ir = _cext.ir @@ -804,9 +805,10 @@ def _emit_dialect(cls) -> None: @classmethod def _emit_module(cls) -> ir.Module: - m = ir.Module.create() - with ir.InsertionPoint(m.body): - cls._emit_dialect() + with ir.Location.unknown() if not ir.Location.current else nullcontext(): + m = ir.Module.create() + with ir.InsertionPoint(m.body): + cls._emit_dialect() return m diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py index 30275cc53d096..30132f891faec 100644 --- a/mlir/test/python/dialects/ext.py +++ b/mlir/test/python/dialects/ext.py @@ -43,7 +43,7 @@ class AddOp(Operation, dialect=MyInt, name="add"): # CHECK: irdl.results(res: %0) # CHECK: } # CHECK: } - with Context(), Location.unknown(): + with Context(): MyInt.load() print(MyInt._mlir_module) @@ -51,13 +51,14 @@ class AddOp(Operation, dialect=MyInt, name="add"): print([i._op_name for i in MyInt.operations]) i32 = IntegerType.get_signless(32) - module = Module.create() - with InsertionPoint(module.body): - two = ConstantOp(IntegerAttr.get(i32, 2)) - three = ConstantOp(IntegerAttr.get(i32, 3)) - add1 = AddOp(two, three) - add2 = AddOp(add1, two) - add3 = AddOp(add2, three) + with Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + two = ConstantOp(IntegerAttr.get(i32, 2)) + three = ConstantOp(IntegerAttr.get(i32, 3)) + add1 = AddOp(two, three) + add2 = AddOp(add1, two) + add3 = AddOp(add2, three) # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32 # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32