diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 061d7620ba077..c464e4da66f17 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -634,6 +634,10 @@ MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op); /// Gets the location of the operation. MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op); +/// Sets the location of the operation. +MLIR_CAPI_EXPORTED void mlirOperationSetLocation(MlirOperation op, + MlirLocation loc); + /// Gets the type id of the operation. /// Returns null if the operation does not have a registered operation /// description. diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 83a8757bb72c7..c20b2111c071e 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3485,15 +3485,21 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") - .def_prop_ro( + .def_prop_rw( "location", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); return PyLocation(operation.getContext(), mlirOperationGetLocation(operation.get())); }, - "Returns the source location the operation was defined or derived " - "from.") + [](PyOperationBase &self, const PyLocation &location) { + PyOperation &operation = self.getOperation(); + mlirOperationSetLocation(operation.get(), location.get()); + }, + nb::for_getter("Returns the source location the operation was " + "defined or derived from."), + nb::for_setter("Sets the source location the operation was defined " + "or derived from.")) .def_prop_ro("parent", [](PyOperationBase &self) -> std::optional> { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index e9844a7cc1909..188186598c5c5 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -656,6 +656,10 @@ MlirLocation mlirOperationGetLocation(MlirOperation op) { return wrap(unwrap(op)->getLoc()); } +void mlirOperationSetLocation(MlirOperation op, MlirLocation loc) { + unwrap(op)->setLoc(unwrap(loc)); +} + MlirTypeID mlirOperationGetTypeID(MlirOperation op) { if (auto info = unwrap(op)->getRegisteredInfo()) return wrap(info->getTypeID()); diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 4a3625c953d52..cb4cfc8c8a6ec 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -696,6 +696,7 @@ def testOperationPrint(): # CHECK: resource1: "0x08 module.operation.print(large_elements_limit=2) + # CHECK-LABEL: TEST: testKnownOpView @run def testKnownOpView(): @@ -969,6 +970,13 @@ def testOperationLoc(): assert op.location == loc assert op.operation.location == loc + another_loc = Location.name("another_loc") + op.location = another_loc + assert op.location == another_loc + assert op.operation.location == another_loc + # CHECK: loc("another_loc") + print(op.location) + # CHECK-LABEL: TEST: testModuleMerge @run