10 changes: 1 addition & 9 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ PYBIND11_MODULE(_mlir, m) {
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
py::arg("operation_name"), py::arg("operation_class"),
py::arg("raw_opview_class"),
"Testing hook for directly registering an operation");

// Aside from making the globals accessible to python, having python manage
Expand All @@ -68,18 +67,11 @@ PYBIND11_MODULE(_mlir, m) {
[dialectClass](py::object opClass) -> py::object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
auto rawSubclass = PyOpView::createRawSubclass(opClass);
PyGlobals::get().registerOperationImpl(operationName, opClass,
rawSubclass);
PyGlobals::get().registerOperationImpl(operationName, opClass);

// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;

// Now create a special "Raw" subclass that passes through
// construction to the OpView parent (bypasses the intermediate
// child's __init__).
opClass.attr("_Raw") = rawSubclass;
return opClass;
});
},
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Bindings/Python/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,16 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"ValueError if the pipeline can't be parsed.")
.def(
"run",
[](PyPassManager &passManager, PyModule &module) {
MlirLogicalResult status =
mlirPassManagerRun(passManager.get(), module.get());
[](PyPassManager &passManager, PyOperationBase &op) {
MlirLogicalResult status = mlirPassManagerRunOnOp(
passManager.get(), op.getOperation().get());
if (mlirLogicalResultIsFailure(status))
throw SetPyError(PyExc_RuntimeError,
"Failure while executing pass pipeline.");
},
py::arg("module"),
"Run the pass manager on the provided module, throw a RuntimeError "
"on failure.")
py::arg("operation"),
"Run the pass manager on the provided operation, throw a "
"RuntimeError on failure.")
.def(
"__str__",
[](PyPassManager &self) {
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,15 @@ MlirOperation mlirOperationCreate(MlirOperationState *state) {
return result;
}

MlirOperation mlirOperationCreateParse(MlirContext context,
MlirStringRef sourceStr,
MlirStringRef sourceName) {

return wrap(
parseSourceString(unwrap(sourceStr), unwrap(context), unwrap(sourceName))
.release());
}

MlirOperation mlirOperationClone(MlirOperation op) {
return wrap(unwrap(op)->clone());
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/CAPI/IR/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
}

MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
MlirModule module) {
return wrap(unwrap(passManager)->run(unwrap(module)));
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
MlirOperation op) {
return wrap(unwrap(passManager)->run(unwrap(op)));
}

void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ LogicalResult mlir::parseSourceFile(

LogicalResult mlir::parseSourceString(llvm::StringRef sourceStr, Block *block,
const ParserConfig &config,
StringRef sourceName,
LocationAttr *sourceFileLoc) {
auto memBuffer = llvm::MemoryBuffer::getMemBuffer(sourceStr);
auto memBuffer = llvm::MemoryBuffer::getMemBuffer(sourceStr, sourceName);
if (!memBuffer)
return failure();

Expand Down
2 changes: 1 addition & 1 deletion mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ globals: "_Globals"
class _Globals:
dialect_search_modules: List[str]
def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ...
def _register_operation_impl(self, operation_name: str, operation_class: type, raw_opview_class: type) -> None: ...
def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
def append_dialect_search_prefix(self, module_name: str) -> None: ...

def register_dialect(dialect_class: type) -> object: ...
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/CAPI/execution_engine.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVMPass());
mlirOpPassManagerAddOwnedPass(
opm, mlirCreateConversionArithToLLVMConversionPass());
MlirLogicalResult status = mlirPassManagerRun(pm, module);
MlirLogicalResult status =
mlirPassManagerRunOnOp(pm, mlirModuleGetOperation(module));
if (mlirLogicalResultIsFailure(status)) {
fprintf(stderr, "Unexpected failure running pass pipeline\n");
exit(2);
Expand Down
97 changes: 49 additions & 48 deletions mlir/test/CAPI/pass.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,16 @@ void testRunPassOnModule(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);

MlirModule module = mlirModuleCreateParse(
ctx,
// clang-format off
mlirStringRefCreateFromCString(
"func.func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
"}"));
// clang-format on
if (mlirModuleIsNull(module)) {
fprintf(stderr, "Unexpected failure parsing module.\n");
const char *funcAsm = //
"func.func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
"} \n";
MlirOperation func =
mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(funcAsm),
mlirStringRefCreateFromCString("funcAsm"));
if (mlirOperationIsNull(func)) {
fprintf(stderr, "Unexpected failure parsing asm.\n");
exit(EXIT_FAILURE);
}

Expand All @@ -56,37 +55,38 @@ void testRunPassOnModule(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirPassManagerAddOwnedPass(pm, printOpStatPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, func);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running pass manager.\n");
exit(EXIT_FAILURE);
}
mlirPassManagerDestroy(pm);
}
mlirModuleDestroy(module);
mlirOperationDestroy(func);
mlirContextDestroy(ctx);
}

void testRunPassOnNestedModule(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);

MlirModule module = mlirModuleCreateParse(
ctx,
// clang-format off
mlirStringRefCreateFromCString(
"func.func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
"} \n"
"module { \n"
" func.func @bar(%arg0 : f32) -> f32 { \n"
" %res = arith.addf %arg0, %arg0 : f32 \n"
" return %res : f32 \n"
" } \n"
"}"));
// clang-format on
if (mlirModuleIsNull(module))
const char *moduleAsm = //
"module { \n"
" func.func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
" } \n"
" module { \n"
" func.func @bar(%arg0 : f32) -> f32 { \n"
" %res = arith.addf %arg0, %arg0 : f32 \n"
" return %res : f32 \n"
" } \n"
" } \n"
"} \n";
MlirOperation module =
mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
mlirStringRefCreateFromCString("moduleAsm"));
if (mlirOperationIsNull(module))
exit(1);

// Run the print-op-stats pass on functions under the top-level module:
Expand All @@ -100,7 +100,7 @@ void testRunPassOnNestedModule(void) {
pm, mlirStringRefCreateFromCString("func.func"));
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success))
exit(2);
mlirPassManagerDestroy(pm);
Expand All @@ -118,13 +118,13 @@ void testRunPassOnNestedModule(void) {
nestedModulePm, mlirStringRefCreateFromCString("func.func"));
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success))
exit(2);
mlirPassManagerDestroy(pm);
}

mlirModuleDestroy(module);
mlirOperationDestroy(module);
mlirContextDestroy(ctx);
}

Expand Down Expand Up @@ -339,16 +339,17 @@ void testExternalPass(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);

MlirModule module = mlirModuleCreateParse(
ctx,
// clang-format off
mlirStringRefCreateFromCString(
"func.func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
"}"));
// clang-format on
if (mlirModuleIsNull(module)) {
const char *moduleAsm = //
"module { \n"
" func.func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
" } \n"
"}";
MlirOperation module =
mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
mlirStringRefCreateFromCString("moduleAsm"));
if (mlirOperationIsNull(module)) {
fprintf(stderr, "Unexpected failure parsing module.\n");
exit(EXIT_FAILURE);
}
Expand Down Expand Up @@ -377,7 +378,7 @@ void testExternalPass(void) {

MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
exit(EXIT_FAILURE);
Expand Down Expand Up @@ -421,7 +422,7 @@ void testExternalPass(void) {
MlirOpPassManager nestedFuncPm =
mlirPassManagerGetNestedUnder(pm, funcOpName);
mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external operation pass.\n");
exit(EXIT_FAILURE);
Expand Down Expand Up @@ -469,7 +470,7 @@ void testExternalPass(void) {

MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
exit(EXIT_FAILURE);
Expand Down Expand Up @@ -516,7 +517,7 @@ void testExternalPass(void) {

MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
stderr,
Expand Down Expand Up @@ -564,7 +565,7 @@ void testExternalPass(void) {

MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
stderr,
Expand All @@ -587,7 +588,7 @@ void testExternalPass(void) {
}

mlirTypeIDAllocatorDestroy(typeIDAllocator);
mlirModuleDestroy(module);
mlirOperationDestroy(module);
mlirContextDestroy(ctx);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __call__(self, module: ir.Module):

def compile(self, module: ir.Module):
"""Compiles the module by invoking the sparse copmiler pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module)
passmanager.PassManager.parse(self.pipeline).run(module.operation)

def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __call__(self, module: ir.Module):

def compile(self, module: ir.Module):
"""Compiles the module by invoking the sparse copmiler pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module)
passmanager.PassManager.parse(self.pipeline).run(module.operation)

def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/python/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def testInvalidModule():
def lowerToLLVM(module):
pm = PassManager.parse(
"builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)")
pm.run(module)
pm.run(module.operation)
return module


Expand Down
2 changes: 1 addition & 1 deletion mlir/test/python/integration/dialects/linalg/opsrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def transform(module, boilerplate):
pm.add("finalize-memref-to-llvm")
pm.add("convert-func-to-llvm")
pm.add("reconcile-unrealized-casts")
pm.run(mod)
pm.run(mod.operation)
return mod


Expand Down
31 changes: 30 additions & 1 deletion mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io
import itertools
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp


def run(f):
Expand Down Expand Up @@ -619,7 +620,7 @@ def testKnownOpView():
# addf should map to a known OpView class in the arithmetic dialect.
# We know the OpView for it defines an 'lhs' attribute.
addf = module.body.operations[2]
# CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
# CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
print(repr(addf))
# CHECK: "custom.f32"()
print(addf.lhs)
Expand Down Expand Up @@ -900,3 +901,31 @@ def testOperationHash():
with ctx, Location.unknown():
op = Operation.create("custom.op1")
assert hash(op) == hash(op.operation)


# CHECK-LABEL: TEST: testOperationParse
@run
def testOperationParse():
with Context() as ctx:
ctx.allow_unregistered_dialects = True

# Generic operation parsing.
m = Operation.parse('module {}')
o = Operation.parse('"test.foo"() : () -> ()')
assert isinstance(m, ModuleOp)
assert type(o) is OpView

# Parsing specific operation.
m = ModuleOp.parse('module {}')
assert isinstance(m, ModuleOp)
try:
ModuleOp.parse('"test.foo"() : () -> ()')
except ValueError as e:
# CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
print(f"error: {e}")
else:
assert False, "expected error"

o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
# CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
print(f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}")
8 changes: 4 additions & 4 deletions mlir/test/python/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import gc, sys
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.func import FuncOp

# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
Expand Down Expand Up @@ -120,11 +121,10 @@ def testInvalidNesting():
# CHECK-LABEL: TEST: testRun
def testRunPipeline():
with Context():
pm = PassManager.parse("builtin.module(print-op-stats{json=false})")
module = Module.parse(r"""func.func @successfulParse() { return }""")
pm.run(module)
pm = PassManager.parse("any(print-op-stats{json=false})")
func = FuncOp.parse(r"""func.func @successfulParse() { return }""")
pm.run(func)
# CHECK: Operations encountered:
# CHECK: builtin.module , 1
# CHECK: func.func , 1
# CHECK: func.return , 1
run(testRunPipeline)