Skip to content

Commit

Permalink
[mlir][python] Extend C/Python API to be usable for CFG construction.
Browse files Browse the repository at this point in the history
* It is pretty clear that no one has tried this yet since it was both incomplete and broken.
* Fixes a symbol hiding issues keeping even the generic builder from constructing an operation with successors.
* Adds ODS support for successors.
* Adds CAPI `mlirBlockGetParentRegion`, `mlirRegionEqual` + tests (and missing test for `mlirBlockGetParentOperation`).
* Adds Python property: `Block.region`.
* Adds Python methods: `Block.create_before` and `Block.create_after`.
* Adds Python property: `InsertionPoint.block`.
* Adds new blocks.py test to verify a plausible CFG construction case.

Differential Revision: https://reviews.llvm.org/D108898
  • Loading branch information
stellaraccident committed Aug 30, 2021
1 parent 57b4605 commit 8e6c55c
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 35 deletions.
7 changes: 7 additions & 0 deletions mlir/include/mlir-c/IR.h
Expand Up @@ -447,6 +447,10 @@ MLIR_CAPI_EXPORTED void mlirRegionDestroy(MlirRegion region);
/// Checks whether a region is null.
static inline bool mlirRegionIsNull(MlirRegion region) { return !region.ptr; }

/// Checks whether two region handles point to the same region. This does not
/// perform deep comparison.
MLIR_CAPI_EXPORTED bool mlirRegionEqual(MlirRegion region, MlirRegion other);

/// Gets the first block in the region.
MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region);

Expand Down Expand Up @@ -496,6 +500,9 @@ MLIR_CAPI_EXPORTED bool mlirBlockEqual(MlirBlock block, MlirBlock other);
/// Returns the closest surrounding operation that contains this block.
MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock);

/// Returns the region that contains this block.
MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block);

/// Returns the block immediately following the given block in its parent
/// region.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block);
Expand Down
47 changes: 45 additions & 2 deletions mlir/lib/Bindings/Python/IRCore.cpp
Expand Up @@ -969,7 +969,6 @@ py::object PyOperation::create(
}
// Unpack/validate successors.
if (successors) {
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
mlirSuccessors.reserve(successors->size());
for (auto *successor : *successors) {
// TODO: Verify successor originate from the same context.
Expand Down Expand Up @@ -2206,6 +2205,13 @@ void mlir::python::populateIRCore(py::module &m) {
return self.getParentOperation()->createOpView();
},
"Returns the owning operation of this block.")
.def_property_readonly(
"region",
[](PyBlock &self) {
MlirRegion region = mlirBlockGetParentRegion(self.get());
return PyRegion(self.getParentOperation(), region);
},
"Returns the owning region of this block.")
.def_property_readonly(
"arguments",
[](PyBlock &self) {
Expand All @@ -2218,6 +2224,40 @@ void mlir::python::populateIRCore(py::module &m) {
return PyOperationList(self.getParentOperation(), self.get());
},
"Returns a forward-optimized sequence of operations.")
.def(
"create_before",
[](PyBlock &self, py::args pyArgTypes) {
self.checkValid();
llvm::SmallVector<MlirType, 4> argTypes;
argTypes.reserve(pyArgTypes.size());
for (auto &pyArg : pyArgTypes) {
argTypes.push_back(pyArg.cast<PyType &>());
}

MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
MlirRegion region = mlirBlockGetParentRegion(self.get());
mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
return PyBlock(self.getParentOperation(), block);
},
"Creates and returns a new Block before this block "
"(with given argument types).")
.def(
"create_after",
[](PyBlock &self, py::args pyArgTypes) {
self.checkValid();
llvm::SmallVector<MlirType, 4> argTypes;
argTypes.reserve(pyArgTypes.size());
for (auto &pyArg : pyArgTypes) {
argTypes.push_back(pyArg.cast<PyType &>());
}

MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
MlirRegion region = mlirBlockGetParentRegion(self.get());
mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
return PyBlock(self.getParentOperation(), block);
},
"Creates and returns a new Block after this block "
"(with given argument types).")
.def(
"__iter__",
[](PyBlock &self) {
Expand Down Expand Up @@ -2270,7 +2310,10 @@ void mlir::python::populateIRCore(py::module &m) {
.def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
py::arg("block"), "Inserts before the block terminator.")
.def("insert", &PyInsertionPoint::insert, py::arg("operation"),
"Inserts an operation.");
"Inserts an operation.")
.def_property_readonly(
"block", [](PyInsertionPoint &self) { return self.getBlock(); },
"Returns the block that this InsertionPoint points to.");

//----------------------------------------------------------------------------
// Mapping of PyAttribute.
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Expand Up @@ -427,6 +427,10 @@ bool mlirOperationVerify(MlirOperation op) {

MlirRegion mlirRegionCreate() { return wrap(new Region); }

bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
return unwrap(region) == unwrap(other);
}

MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
Region *cppRegion = unwrap(region);
if (cppRegion->empty())
Expand Down Expand Up @@ -492,6 +496,10 @@ MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
return wrap(unwrap(block)->getParentOp());
}

MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
return wrap(unwrap(block)->getParent());
}

MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
return wrap(unwrap(block)->getNextNode());
}
Expand Down
39 changes: 24 additions & 15 deletions mlir/test/CAPI/ir.c
Expand Up @@ -323,13 +323,20 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));

// Verify that parent operation and block report correctly.
// CHECK: Parent operation eq: 1
fprintf(stderr, "Parent operation eq: %d\n",
mlirOperationEqual(mlirOperationGetParentOperation(operation),
parentOperation));
// CHECK: Block eq: 1
fprintf(stderr, "Block eq: %d\n",
mlirBlockEqual(mlirOperationGetBlock(operation), block));
// CHECK: Parent operation eq: 1
// CHECK: Block eq: 1
// CHECK: Block parent operation eq: 1
fprintf(
stderr, "Block parent operation eq: %d\n",
mlirOperationEqual(mlirBlockGetParentOperation(block), parentOperation));
// CHECK: Block parent region eq: 1
fprintf(stderr, "Block parent region eq: %d\n",
mlirRegionEqual(mlirBlockGetParentRegion(block), region));

// In the module we created, the first operation of the first function is
// an "memref.dim", which has an attribute and a single result that we can
Expand Down Expand Up @@ -441,7 +448,8 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
operation, mlirStringRefCreateFromCString("elts"),
mlirDenseElementsAttrInt32Get(
mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
mlirAttributeGetNull()), 4, eltsData));
mlirAttributeGetNull()),
4, eltsData));
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
mlirOpPrintingFlagsPrintGenericOpForm(flags);
Expand Down Expand Up @@ -909,25 +917,25 @@ int printBuiltinAttributes(MlirContext ctx) {
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
2, ints8);
MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
mlirRankedTensorTypeGet(2, shape,
mlirIntegerTypeUnsignedGet(ctx, 32), encoding),
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
encoding),
2, uints32);
MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
2, ints32);
MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
mlirRankedTensorTypeGet(2, shape,
mlirIntegerTypeUnsignedGet(ctx, 64), encoding),
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
encoding),
2, uints64);
MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
2, ints64);
MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
2, floats);
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 2,
floats);
MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding),
2, doubles);
mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2,
doubles);

if (!mlirAttributeIsADenseElements(boolElements) ||
!mlirAttributeIsADenseElements(uint8Elements) ||
Expand Down Expand Up @@ -1084,8 +1092,8 @@ int printBuiltinAttributes(MlirContext ctx) {
mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), encoding),
2, indices);
MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding),
2, floats);
mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding), 2,
floats);
MlirAttribute sparseAttr = mlirSparseElementsAttribute(
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
indicesAttr, valuesAttr);
Expand Down Expand Up @@ -1635,11 +1643,12 @@ int testClone() {
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("std"));
MlirLocation loc = mlirLocationUnknownGet(ctx);
MlirType indexType = mlirIndexTypeGet(ctx);
MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");

MlirAttribute indexZeroLiteral =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
MlirOperationState constZeroState = mlirOperationStateGet(
mlirStringRefCreateFromCString("std.constant"), loc);
mlirOperationStateAddResults(&constZeroState, 1, &indexType);
Expand Down
42 changes: 32 additions & 10 deletions mlir/test/mlir-tblgen/op-python-bindings.td
Expand Up @@ -27,9 +27,10 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: operands.append(variadic1)
// CHECK: operands.append(non_variadic)
// CHECK: if variadic2 is not None: operands.append(variadic2)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: loc=loc, ip=ip))
// CHECK: successors=_ods_successors, loc=loc, ip=ip))

// CHECK: @builtins.property
// CHECK: def variadic1(self):
Expand Down Expand Up @@ -68,9 +69,10 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: if variadic1 is not None: results.append(variadic1)
// CHECK: results.append(non_variadic)
// CHECK: if variadic2 is not None: results.append(variadic2)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: loc=loc, ip=ip))
// CHECK: successors=_ods_successors, loc=loc, ip=ip))

// CHECK: @builtins.property
// CHECK: def variadic1(self):
Expand Down Expand Up @@ -112,9 +114,10 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: attributes["in"] = in_
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: loc=loc, ip=ip))
// CHECK: successors=_ods_successors, loc=loc, ip=ip))

// CHECK: @builtins.property
// CHECK: def i32attr(self):
Expand Down Expand Up @@ -152,9 +155,10 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = is_
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: loc=loc, ip=ip))
// CHECK: successors=_ods_successors, loc=loc, ip=ip))

// CHECK: @builtins.property
// CHECK: def in_(self):
Expand All @@ -177,9 +181,10 @@ def EmptyOp : TestOp<"empty">;
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: loc=loc, ip=ip))
// CHECK: successors=_ods_successors, loc=loc, ip=ip))

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
Expand All @@ -195,9 +200,10 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: operands.append(_gen_arg_0)
// CHECK: operands.append(f32)
// CHECK: operands.append(_gen_arg_2)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: loc=loc, ip=ip))
// CHECK: successors=_ods_successors, loc=loc, ip=ip))

// CHECK: @builtins.property
// CHECK: def f32(self):
Expand Down Expand Up @@ -226,9 +232,10 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: attributes = {}
// CHECK: operands.append(non_variadic)
// CHECK: operands.extend(variadic)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: loc=loc, ip=ip))
// CHECK: successors=_ods_successors, loc=loc, ip=ip))

// CHECK: @builtins.property
// CHECK: def non_variadic(self):
Expand All @@ -253,9 +260,10 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
// CHECK: attributes = {}
// CHECK: results.extend(variadic)
// CHECK: results.append(non_variadic)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: loc=loc, ip=ip))
// CHECK: successors=_ods_successors, loc=loc, ip=ip))

// CHECK: @builtins.property
// CHECK: def variadic(self):
Expand All @@ -278,9 +286,10 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: operands.append(in_)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: loc=loc, ip=ip))
// CHECK: successors=_ods_successors, loc=loc, ip=ip))

// CHECK: @builtins.property
// CHECK: def in_(self):
Expand Down Expand Up @@ -346,9 +355,10 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: results.append(f64)
// CHECK: operands.append(i32)
// CHECK: operands.append(f32)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: loc=loc, ip=ip))
// CHECK: successors=_ods_successors, loc=loc, ip=ip))

// CHECK: @builtins.property
// CHECK: def i32(self):
Expand All @@ -368,3 +378,15 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: return self.operation.results[1]
let results = (outs I64:$i64, F64:$f64);
}

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.with_successors"
def WithSuccessorsOp : TestOp<"with_successors"> {
// CHECK-NOT: _ods_successors = None
// CHECK: _ods_successors = []
// CHECK-NEXT: _ods_successors.append(successor)
// CHECK-NEXT: _ods_successors.extend(successors)
let successors = (successor AnySuccessor:$successor,
VariadicSuccessor<AnySuccessor>:$successors);
}

0 comments on commit 8e6c55c

Please sign in to comment.