729 changes: 582 additions & 147 deletions mlir/lib/Bindings/Python/IRModules.cpp

Large diffs are not rendered by default.

263 changes: 185 additions & 78 deletions mlir/lib/Bindings/Python/IRModules.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,83 @@
namespace mlir {
namespace python {

class PyBlock;
class PyLocation;
class PyMlirContext;
class PyModule;
class PyOperation;
class PyType;

/// Holds a C++ PyMlirContext and associated py::object, making it convenient
/// to have an auto-releasing C++-side keep-alive reference to the context.
/// The reference to the PyMlirContext is a simple C++ reference and the
/// py::object holds the reference count which keeps it alive.
class PyMlirContextRef {
/// Template for a reference to a concrete type which captures a python
/// reference to its underlying python object.
template <typename T>
class PyObjectRef {
public:
PyMlirContextRef(PyMlirContext &referrent, pybind11::object object)
: referrent(referrent), object(std::move(object)) {}
~PyMlirContextRef() {}
PyObjectRef(T *referrent, pybind11::object object)
: referrent(referrent), object(std::move(object)) {
assert(this->referrent &&
"cannot construct PyObjectRef with null referrent");
assert(this->object && "cannot construct PyObjectRef with null object");
}
PyObjectRef(PyObjectRef &&other)
: referrent(other.referrent), object(std::move(other.object)) {
other.referrent = nullptr;
assert(!other.object);
}
PyObjectRef(const PyObjectRef &other)
: referrent(other.referrent), object(other.object /* copies */) {}
~PyObjectRef() {}

int getRefCount() {
if (!object)
return 0;
return object.ref_count();
}

/// Releases the object held by this instance, causing its reference count
/// to remain artifically inflated by one. This must be used to return
/// the referenced PyMlirContext from a function. Otherwise, the destructor
/// of this reference would be called prior to the default take_ownership
/// policy assuming that the reference count has been transferred to it.
PyMlirContext *release();
/// Releases the object held by this instance, returning it.
/// This is the proper thing to return from a function that wants to return
/// the reference. Note that this does not work from initializers.
pybind11::object releaseObject() {
assert(referrent && object);
referrent = nullptr;
auto stolen = std::move(object);
return stolen;
}

PyMlirContext &operator->() { return referrent; }
pybind11::object getObject() { return object; }
T *operator->() {
assert(referrent && object);
return referrent;
}
pybind11::object getObject() {
assert(referrent && object);
return object;
}
operator bool() const { return referrent && object; }

private:
PyMlirContext &referrent;
T *referrent;
pybind11::object object;
};

using PyMlirContextRef = PyObjectRef<PyMlirContext>;

/// Wrapper around MlirContext.
class PyMlirContext {
public:
PyMlirContext() = delete;
PyMlirContext(const PyMlirContext &) = delete;
PyMlirContext(PyMlirContext &&) = delete;

/// For the case of a python __init__ (py::init) method, pybind11 is quite
/// strict about needing to return a pointer that is not yet associated to
/// an py::object. Since the forContext() method acts like a pool, possibly
/// returning a recycled context, it does not satisfy this need. The usual
/// way in python to accomplish such a thing is to override __new__, but
/// that is also not supported by pybind11. Instead, we use this entry
/// point which always constructs a fresh context (which cannot alias an
/// existing one because it is fresh).
static PyMlirContext *createNewContextForInit();

/// Returns a context reference for the singleton PyMlirContext wrapper for
/// the given context.
static PyMlirContextRef forContext(MlirContext context);
Expand All @@ -63,29 +105,45 @@ class PyMlirContext {
/// Gets a strong reference to this context, which will ensure it is kept
/// alive for the life of the reference.
PyMlirContextRef getRef() {
return PyMlirContextRef(
*this, pybind11::reinterpret_borrow<pybind11::object>(handle));
return PyMlirContextRef(this, pybind11::cast(this));
}

/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();

/// Gets the count of live operations associated with this context.
/// Used for testing.
size_t getLiveOperationCount();

/// Creates an operation. See corresponding python docstring.
pybind11::object
createOperation(std::string name, PyLocation location,
llvm::Optional<std::vector<PyType *>> results,
llvm::Optional<pybind11::dict> attributes,
llvm::Optional<std::vector<PyBlock *>> successors,
int regions);

private:
PyMlirContext(MlirContext context);

// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
// preserving the relationship that an MlirContext maps to a single
// PyMlirContext wrapper. This could be replaced in the future with an
// extension mechanism on the MlirContext for stashing user pointers.
// Note that this holds a handle, which does not imply ownership.
// Mappings will be removed when the context is destructed.
using LiveContextMap =
llvm::DenseMap<void *, std::pair<pybind11::handle, PyMlirContext *>>;
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
static LiveContextMap &getLiveContexts();

// Interns all live operations associated with this context. Operations
// tracked in this map are valid. When an operation is invalidated, it is
// removed from this map, and while it still exists as an instance, any
// attempt to access it will raise an error.
using LiveOperationMap =
llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
LiveOperationMap liveOperations;

MlirContext context;
// The handle is set as part of lookup with forContext() (post construction).
pybind11::handle handle;
friend class PyOperation;
};

/// Base class for all objects that directly or indirectly depend on an
Expand All @@ -94,7 +152,10 @@ class PyMlirContext {
/// Immutable objects that depend on a context extend this directly.
class BaseContextObject {
public:
BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {}
BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
assert(this->contextRef &&
"context object constructed with null context ref");
}

/// Accesses the context reference.
PyMlirContextRef &getContext() { return contextRef; }
Expand All @@ -112,88 +173,134 @@ class PyLocation : public BaseContextObject {
};

/// Wrapper around MlirModule.
/// This is the top-level, user-owned object that contains regions/ops/blocks.
class PyModule;
using PyModuleRef = PyObjectRef<PyModule>;
class PyModule : public BaseContextObject {
public:
PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
/// Creates a reference to the module
static PyModuleRef create(PyMlirContextRef contextRef, MlirModule module);
PyModule(PyModule &) = delete;
PyModule(PyModule &&other)
: BaseContextObject(std::move(other.getContext())) {
module = other.module;
other.module.ptr = nullptr;
}
~PyModule() {
if (module.ptr)
mlirModuleDestroy(module);
}

/// Gets the backing MlirModule.
MlirModule get() { return module; }

/// Gets a strong reference to this module.
PyModuleRef getRef() {
return PyModuleRef(this,
pybind11::reinterpret_borrow<pybind11::object>(handle));
}

private:
PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
MlirModule module;
pybind11::handle handle;
};

/// Wrapper around PyOperation.
/// Operations exist in either an attached (dependent) or detached (top-level)
/// state. In the detached state (as on creation), an operation is owned by
/// the creator and its lifetime extends either until its reference count
/// drops to zero or it is attached to a parent, at which point its lifetime
/// is bounded by its top-level parent reference.
class PyOperation;
using PyOperationRef = PyObjectRef<PyOperation>;
class PyOperation : public BaseContextObject {
public:
~PyOperation();
/// Returns a PyOperation for the given MlirOperation, optionally associating
/// it with a parentKeepAlive (which must match on all such calls for the
/// same operation).
static PyOperationRef
forOperation(PyMlirContextRef contextRef, MlirOperation operation,
pybind11::object parentKeepAlive = pybind11::object());

/// Creates a detached operation. The operation must not be associated with
/// any existing live operation.
static PyOperationRef
createDetached(PyMlirContextRef contextRef, MlirOperation operation,
pybind11::object parentKeepAlive = pybind11::object());

/// Gets the backing operation.
MlirOperation get() {
checkValid();
return operation;
}

PyOperationRef getRef() {
return PyOperationRef(
this, pybind11::reinterpret_borrow<pybind11::object>(handle));
}

bool isAttached() { return attached; }
void setAttached() {
assert(!attached && "operation already attached");
attached = true;
}
void checkValid();

private:
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
static PyOperationRef createInstance(PyMlirContextRef contextRef,
MlirOperation operation,
pybind11::object parentKeepAlive);

MlirOperation operation;
pybind11::handle handle;
// Keeps the parent alive, regardless of whether it is an Operation or
// Module.
// TODO: As implemented, this facility is only sufficient for modeling the
// trivial module parent back-reference. Generalize this to also account for
// transitions from detached to attached and address TODOs in the
// ir_operation.py regarding testing corresponding lifetime guarantees.
pybind11::object parentKeepAlive;
bool attached = true;
bool valid = true;
};

/// Wrapper around an MlirRegion.
/// Note that region can exist in a detached state (where this instance is
/// responsible for clearing) or an attached state (where its owner is
/// responsible).
///
/// This python wrapper retains a redundant reference to its creating context
/// in order to facilitate checking that parts of the operation hierarchy
/// are only assembled from the same context.
/// Regions are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached regions.
class PyRegion {
public:
PyRegion(MlirContext context, MlirRegion region, bool detached)
: context(context), region(region), detached(detached) {}
PyRegion(PyRegion &&other)
: context(other.context), region(other.region), detached(other.detached) {
other.detached = false;
}
~PyRegion() {
if (detached)
mlirRegionDestroy(region);
PyRegion(PyOperationRef parentOperation, MlirRegion region)
: parentOperation(std::move(parentOperation)), region(region) {
assert(!mlirRegionIsNull(region) && "python region cannot be null");
}

// Call prior to attaching the region to a parent.
// This will transition to the attached state and will throw an exception
// if already attached.
void attachToParent();
MlirRegion get() { return region; }
PyOperationRef &getParentOperation() { return parentOperation; }

MlirContext context;
MlirRegion region;
void checkValid() { return parentOperation->checkValid(); }

private:
bool detached;
PyOperationRef parentOperation;
MlirRegion region;
};

/// Wrapper around an MlirBlock.
/// Note that blocks can exist in a detached state (where this instance is
/// responsible for clearing) or an attached state (where its owner is
/// responsible).
///
/// This python wrapper retains a redundant reference to its creating context
/// in order to facilitate checking that parts of the operation hierarchy
/// are only assembled from the same context.
/// Blocks are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached blocks.
class PyBlock {
public:
PyBlock(MlirContext context, MlirBlock block, bool detached)
: context(context), block(block), detached(detached) {}
PyBlock(PyBlock &&other)
: context(other.context), block(other.block), detached(other.detached) {
other.detached = false;
}
~PyBlock() {
if (detached)
mlirBlockDestroy(block);
PyBlock(PyOperationRef parentOperation, MlirBlock block)
: parentOperation(std::move(parentOperation)), block(block) {
assert(!mlirBlockIsNull(block) && "python block cannot be null");
}

// Call prior to attaching the block to a parent.
// This will transition to the attached state and will throw an exception
// if already attached.
void attachToParent();
MlirBlock get() { return block; }
PyOperationRef &getParentOperation() { return parentOperation; }

MlirContext context;
MlirBlock block;
void checkValid() { return parentOperation->checkValid(); }

private:
bool detached;
PyOperationRef parentOperation;
MlirBlock block;
};

/// Wrapper around the generic MlirAttribute.
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Bindings/Python/PybindUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,16 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Twine.h"

namespace pybind11 {
namespace detail {
template <typename T>
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
} // namespace detail
} // namespace pybind11

namespace mlir {
namespace python {

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/CAPI/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_mlir_library(MLIRCAPIIR
${MLIR_MAIN_INCLUDE_DIR}/mlir-c

LINK_LIBS PUBLIC
MLIRStandardOps
MLIRIR
MLIRParser
MLIRSupport
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
Expand All @@ -25,6 +26,10 @@ using namespace mlir;

MlirContext mlirContextCreate() {
auto *context = new MLIRContext(/*loadAllDialects=*/false);
// TODO: Come up with a story for which dialects to load into the context
// and do not expand this beyond StandardOps until done so. This is loaded
// by default here because it is hard to make progress otherwise.
context->loadDialect<StandardOpsDialect>();
return wrap(context);
}

Expand All @@ -34,6 +39,14 @@ int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {

void mlirContextDestroy(MlirContext context) { delete unwrap(context); }

void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
unwrap(context)->allowUnregisteredDialects(allow);
}

int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
return unwrap(context)->allowsUnregisteredDialects();
}

/* ========================================================================== */
/* Location API. */
/* ========================================================================== */
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/Bindings/Python/ir_attributes.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
# RUN: %PYTHON %s | FileCheck %s

import gc
import mlir

def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert mlir.ir.Context._get_live_count() == 0


# CHECK-LABEL: TEST: testParsePrint
def testParsePrint():
ctx = mlir.ir.Context()
t = ctx.parse_attr('"hello"')
ctx = None
gc.collect()
# CHECK: "hello"
print(str(t))
# CHECK: Attribute("hello")
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Bindings/Python/ir_location.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
# RUN: %PYTHON %s | FileCheck %s

import gc
import mlir

def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert mlir.ir.Context._get_live_count() == 0


# CHECK-LABEL: TEST: testUnknown
def testUnknown():
ctx = mlir.ir.Context()
loc = ctx.get_unknown_location()
ctx = None
gc.collect()
# CHECK: unknown str: loc(unknown)
print("unknown str:", str(loc))
# CHECK: unknown repr: loc(unknown)
Expand All @@ -22,6 +28,8 @@ def testUnknown():
def testFileLineCol():
ctx = mlir.ir.Context()
loc = ctx.get_file_location("foo.txt", 123, 56)
ctx = None
gc.collect()
# CHECK: file str: loc("foo.txt":123:56)
print("file str:", str(loc))
# CHECK: file repr: loc("foo.txt":123:56)
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Bindings/Python/ir_module.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# RUN: %PYTHON %s | FileCheck %s

import gc
import mlir

def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert mlir.ir.Context._get_live_count() == 0


# Verify successful parse.
# CHECK-LABEL: TEST: testParseSuccess
# CHECK: module @successfulParse
def testParseSuccess():
ctx = mlir.ir.Context()
module = ctx.parse_module(r"""module @successfulParse {}""")
print("CLEAR CONTEXT")
ctx = None # Ensure that module captures the context.
gc.collect()
module.dump() # Just outputs to stderr. Verifies that it functions.
print(str(module))

Expand Down Expand Up @@ -47,3 +54,33 @@ def testRoundtripUnicode():
print(str(module))

run(testRoundtripUnicode)


# Tests that module.operation works and correctly interns instances.
# CHECK-LABEL: TEST: testModuleOperation
def testModuleOperation():
ctx = mlir.ir.Context()
module = ctx.parse_module(r"""module @successfulParse {}""")
op1 = module.operation
assert ctx._get_live_operation_count() == 1
# CHECK: module @successfulParse
print(op1)

# Ensure that operations are the same on multiple calls.
op2 = module.operation
assert ctx._get_live_operation_count() == 1
assert op1 is op2

# Ensure that if module is de-referenced, the operations are still valid.
module = None
gc.collect()
print(op1)

# Collect and verify lifetime.
op1 = None
op2 = None
gc.collect()
print("LIVE OPERATIONS:", ctx._get_live_operation_count())
assert ctx._get_live_operation_count() == 0

run(testModuleOperation)
222 changes: 174 additions & 48 deletions mlir/test/Bindings/Python/ir_operation.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,197 @@
# RUN: %PYTHON %s | FileCheck %s

import gc
import itertools
import mlir

def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert mlir.ir.Context._get_live_count() == 0


# CHECK-LABEL: TEST: testDetachedRegionBlock
def testDetachedRegionBlock():
# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
def testTraverseOpRegionBlockIterators():
ctx = mlir.ir.Context()
t = mlir.ir.F32Type(ctx)
region = ctx.create_region()
block = ctx.create_block([t, t])
# CHECK: <<UNLINKED BLOCK>>
print(block)
ctx.allow_unregistered_dialects = True
module = ctx.parse_module(r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""")
op = module.operation
# Get the block using iterators off of the named collections.
regions = list(op.regions)
blocks = list(regions[0].blocks)
# CHECK: MODULE REGIONS=1 BLOCKS=1
print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

run(testDetachedRegionBlock)
# Get the regions and blocks from the default collections.
default_regions = list(op)
default_blocks = list(default_regions[0])
# They should compare equal regardless of how obtained.
assert default_regions == regions
assert default_blocks == blocks

# Should be able to get the operations from either the named collection
# or the block.
operations = list(blocks[0].operations)
default_operations = list(blocks[0])
assert default_operations == operations

# CHECK-LABEL: TEST: testBlockTypeContextMismatch
def testBlockTypeContextMismatch():
c1 = mlir.ir.Context()
c2 = mlir.ir.Context()
t1 = mlir.ir.F32Type(c1)
t2 = mlir.ir.F32Type(c2)
try:
block = c1.create_block([t1, t2])
except ValueError as e:
# CHECK: ERROR: All types used to construct a block must be from the same context as the block
print("ERROR:", e)
def walk_operations(indent, op):
for i, region in enumerate(op):
print(f"{indent}REGION {i}:")
for j, block in enumerate(region):
print(f"{indent} BLOCK {j}:")
for k, child_op in enumerate(block):
print(f"{indent} OP {k}: {child_op}")
walk_operations(indent + " ", child_op)

# CHECK: REGION 0:
# CHECK: BLOCK 0:
# CHECK: OP 0: func
# CHECK: REGION 0:
# CHECK: BLOCK 0:
# CHECK: OP 0: %0 = "custom.addi"
# CHECK: OP 1: return
# CHECK: OP 1: "module_terminator"
walk_operations("", op)

run(testBlockTypeContextMismatch)
run(testTraverseOpRegionBlockIterators)


# CHECK-LABEL: TEST: testBlockAppend
def testBlockAppend():
# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
def testTraverseOpRegionBlockIndices():
ctx = mlir.ir.Context()
t = mlir.ir.F32Type(ctx)
region = ctx.create_region()
try:
region.first_block
except IndexError:
pass
else:
raise RuntimeError("Expected exception not raised")
ctx.allow_unregistered_dialects = True
module = ctx.parse_module(r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""")

def walk_operations(indent, op):
for i in range(len(op.regions)):
region = op.regions[i]
print(f"{indent}REGION {i}:")
for j in range(len(region.blocks)):
block = region.blocks[j]
print(f"{indent} BLOCK {j}:")
for k in range(len(block.operations)):
child_op = block.operations[k]
print(f"{indent} OP {k}: {child_op}")
walk_operations(indent + " ", child_op)

# CHECK: REGION 0:
# CHECK: BLOCK 0:
# CHECK: OP 0: func
# CHECK: REGION 0:
# CHECK: BLOCK 0:
# CHECK: OP 0: %0 = "custom.addi"
# CHECK: OP 1: return
# CHECK: OP 1: "module_terminator"
walk_operations("", module.operation)

run(testTraverseOpRegionBlockIndices)


# CHECK-LABEL: TEST: testDetachedOperation
def testDetachedOperation():
ctx = mlir.ir.Context()
ctx.allow_unregistered_dialects = True
loc = ctx.get_unknown_location()
i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
op1 = ctx.create_operation(
"custom.op1", loc, results=[i32, i32], regions=1, attributes={
"foo": mlir.ir.StringAttr.get(ctx, "foo_value"),
"bar": mlir.ir.StringAttr.get(ctx, "bar_value"),
})
# CHECK: %0:2 = "custom.op1"() ( {
# CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
print(op1)

# TODO: Check successors once enough infra exists to do it properly.

run(testDetachedOperation)

block = ctx.create_block([t, t])
region.append_block(block)

# CHECK-LABEL: TEST: testOperationInsert
def testOperationInsert():
ctx = mlir.ir.Context()
ctx.allow_unregistered_dialects = True
module = ctx.parse_module(r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""")

# Create test op.
loc = ctx.get_unknown_location()
op1 = ctx.create_operation("custom.op1", loc)
op2 = ctx.create_operation("custom.op2", loc)

func = module.operation.regions[0].blocks[0].operations[0]
entry_block = func.regions[0].blocks[0]
entry_block.operations.insert(0, op1)
entry_block.operations.insert(1, op2)
# CHECK: func @f1
# CHECK: "custom.op1"()
# CHECK: "custom.op2"()
# CHECK: %0 = "custom.addi"
print(module)

# Trying to add a previously added op should raise.
try:
region.append_block(block)
entry_block.operations.insert(0, op1)
except ValueError:
pass
else:
raise RuntimeError("Expected exception not raised")

block2 = ctx.create_block([t])
region.insert_block(1, block2)
# CHECK: <<UNLINKED BLOCK>>
block_first = region.first_block
print(block_first)
block_next = block_first.next_in_region
try:
block_next = block_next.next_in_region
except IndexError:
pass
else:
raise RuntimeError("Expected exception not raised")
assert False, "expected insert of attached op to raise"

run(testOperationInsert)


# CHECK-LABEL: TEST: testOperationWithRegion
def testOperationWithRegion():
ctx = mlir.ir.Context()
ctx.allow_unregistered_dialects = True
loc = ctx.get_unknown_location()
i32 = mlir.ir.IntegerType.get_signed(ctx, 32)
op1 = ctx.create_operation("custom.op1", loc, regions=1)
block = op1.regions[0].blocks.append(i32, i32)
# CHECK: "custom.op1"() ( {
# CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors
# CHECK: "custom.terminator"() : () -> ()
# CHECK: }) : () -> ()
terminator = ctx.create_operation("custom.terminator", loc)
block.operations.insert(0, terminator)
print(op1)

# Now add the whole operation to another op.
# TODO: Verify lifetime hazard by nulling out the new owning module and
# accessing op1.
# TODO: Also verify accessing the terminator once both parents are nulled
# out.
module = ctx.parse_module(r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""")
func = module.operation.regions[0].blocks[0].operations[0]
entry_block = func.regions[0].blocks[0]
entry_block.operations.insert(0, op1)
# CHECK: func @f1
# CHECK: "custom.op1"()
# CHECK: "custom.terminator"
# CHECK: %0 = "custom.addi"
print(module)

run(testBlockAppend)
run(testOperationWithRegion)
5 changes: 5 additions & 0 deletions mlir/test/Bindings/Python/ir_types.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
# RUN: %PYTHON %s | FileCheck %s

import gc
import mlir

def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert mlir.ir.Context._get_live_count() == 0


# CHECK-LABEL: TEST: testParsePrint
def testParsePrint():
ctx = mlir.ir.Context()
t = ctx.parse_type("i32")
ctx = None
gc.collect()
# CHECK: i32
print(str(t))
# CHECK: Type(i32)
Expand Down