Skip to content

Commit

Permalink
[mlir][python] value casting (#69644)
Browse files Browse the repository at this point in the history
This PR adds "value casting", i.e., a mechanism to wrap `ir.Value` in a
proxy class that overloads dunders such as `__add__`, `__sub__`, and
`__mul__` for fun and great profit.

This is thematically similar to
bfb1ba7
and
9566ee2.
The example in the test demonstrates the value of the feature (no pun
intended):

```python
    @register_value_caster(F16Type.static_typeid)
    @register_value_caster(F32Type.static_typeid)
    @register_value_caster(F64Type.static_typeid)
    @register_value_caster(IntegerType.static_typeid)
    class ArithValue(Value):
        __add__ = partialmethod(_binary_op, op="add")
        __sub__ = partialmethod(_binary_op, op="sub")
        __mul__ = partialmethod(_binary_op, op="mul")

    a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
    b = a + a
    # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
    print(b)

    a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
    b = a - a
    # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
    print(b)

    a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
    b = a * a
    # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
    print(b)
```

**EDIT**: this now goes through the bindings and thus supports automatic
casting of `OpResult` (including as an element of `OpResultList`),
`BlockArgument` (including as an element of `BlockArgumentList`), as
well as `Value`.
  • Loading branch information
makslevental committed Nov 7, 2023
1 parent 867ece1 commit 7c85086
Show file tree
Hide file tree
Showing 16 changed files with 371 additions and 58 deletions.
23 changes: 19 additions & 4 deletions mlir/include/mlir-c/Bindings/Python/Interop.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,28 @@

/** Attribute on main C extension module (_mlir) that corresponds to the
* type caster registration binding. The signature of the function is:
* def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
* bool replace)
* where replace indicates the typeCaster should replace any existing registered
* type casters (such as those for upstream ConcreteTypes).
* def register_type_caster(MlirTypeID mlirTypeID, *, bool replace)
* which then takes a typeCaster (register_type_caster is meant to be used as a
* decorator from python), and where replace indicates the typeCaster should
* replace any existing registered type casters (such as those for upstream
* ConcreteTypes). The interface of the typeCaster is: def type_caster(ir.Type)
* -> SubClassTypeT where SubClassTypeT indicates the result should be a
* subclass (inherit from) ir.Type.
*/
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"

/** Attribute on main C extension module (_mlir) that corresponds to the
* value caster registration binding. The signature of the function is:
* def register_value_caster(MlirTypeID mlirTypeID, *, bool replace)
* which then takes a valueCaster (register_value_caster is meant to be used as
* a decorator, from python), and where replace indicates the valueCaster should
* replace any existing registered value casters. The interface of the
* valueCaster is: def value_caster(ir.Value) -> SubClassValueT where
* SubClassValueT indicates the result should be a subclass (inherit from)
* ir.Value.
*/
#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster"

/// Gets a void* from a wrapped struct. Needed because const cast is different
/// between C/C++.
#ifdef __cplusplus
Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Bindings/Python/PybindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ struct type_caster<MlirValue> {
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Value")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
.release();
};
};
Expand Down Expand Up @@ -496,11 +497,10 @@ class mlir_type_subclass : public pure_subclass {
if (getTypeIDFunction) {
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
getTypeIDFunction(),
pybind11::cpp_function(
[thisClass = thisClass](const py::object &mlirType) {
return thisClass(mlirType);
}));
getTypeIDFunction())(pybind11::cpp_function(
[thisClass = thisClass](const py::object &mlirType) {
return thisClass(mlirType);
}));
}
}
};
Expand Down
14 changes: 13 additions & 1 deletion mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ class PyGlobals {
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
bool replace = false);

/// Adds a user-friendly value caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerValueCaster(MlirTypeID mlirTypeID,
pybind11::function valueCaster,
bool replace = false);

/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
Expand All @@ -86,6 +93,10 @@ class PyGlobals {
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);

/// Returns the custom value caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);

/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
Expand All @@ -109,7 +120,8 @@ class PyGlobals {
llvm::StringMap<pybind11::object> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;

/// Map of MlirTypeID to custom value caster.
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
Expand Down
31 changes: 27 additions & 4 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1899,13 +1899,28 @@ bool PyTypeID::operator==(const PyTypeID &other) const {
}

//------------------------------------------------------------------------------
// PyValue and subclases.
// PyValue and subclasses.
//------------------------------------------------------------------------------

pybind11::object PyValue::getCapsule() {
return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
}

pybind11::object PyValue::maybeDownCast() {
MlirType type = mlirValueGetType(get());
MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
std::optional<pybind11::function> valueCaster =
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
// py::return_value_policy::move means use std::move to move the return value
// contents into a new instance that will be owned by Python.
py::object thisObj = py::cast(this, py::return_value_policy::move);
if (!valueCaster)
return thisObj;
return valueCaster.value()(thisObj);
}

PyValue PyValue::createFromCapsule(pybind11::object capsule) {
MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
if (mlirValueIsNull(value))
Expand Down Expand Up @@ -2121,6 +2136,8 @@ class PyConcreteValue : public PyValue {
return DerivedTy::isaFunction(otherValue);
},
py::arg("other_value"));
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](DerivedTy &self) { return self.maybeDownCast(); });
DerivedTy::bindDerived(cls);
}

Expand Down Expand Up @@ -2193,6 +2210,7 @@ class PyBlockArgumentList
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
static constexpr const char *pyClassName = "BlockArgumentList";
using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;

PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
intptr_t startIndex = 0, intptr_t length = -1,
Expand Down Expand Up @@ -2241,6 +2259,7 @@ class PyBlockArgumentList
class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
public:
static constexpr const char *pyClassName = "OpOperandList";
using SliceableT = Sliceable<PyOpOperandList, PyValue>;

PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
Expand Down Expand Up @@ -2296,14 +2315,15 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
public:
static constexpr const char *pyClassName = "OpResultList";
using SliceableT = Sliceable<PyOpResultList, PyOpResult>;

PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirOperationGetNumResults(operation->get())
: length,
step),
operation(operation) {}
operation(std::move(operation)) {}

static void bindDerived(ClassTy &c) {
c.def_property_readonly("types", [](PyOpResultList &self) {
Expand Down Expand Up @@ -2892,7 +2912,8 @@ void mlir::python::populateIRCore(py::module &m) {
.str());
}
return PyOpResult(operation.getRef(),
mlirOperationGetResult(operation, 0));
mlirOperationGetResult(operation, 0))
.maybeDownCast();
},
"Shortcut to get an op result if it has only one (throws an error "
"otherwise).")
Expand Down Expand Up @@ -3566,7 +3587,9 @@ void mlir::python::populateIRCore(py::module &m) {
[](PyValue &self, PyValue &with) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
kValueReplaceAllUsesWithDocstring);
kValueReplaceAllUsesWithDocstring)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyValue &self) { return self.maybeDownCast(); });
PyBlockArgument::bind(m);
PyOpResult::bind(m);
PyOpOperand::bind(m);
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Bindings/Python/IRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
found = std::move(typeCaster);
}

void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
pybind11::function valueCaster,
bool replace) {
pybind11::object &found = valueCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Value caster is already registered: " +
py::repr(found).cast<std::string>());
found = std::move(valueCaster);
}

void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
Expand Down Expand Up @@ -134,6 +144,17 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
return std::nullopt;
}

std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
const auto foundIt = valueCasterMap.find(mlirTypeID);
if (foundIt != valueCasterMap.end()) {
assert(foundIt->second && "value caster is defined");
return foundIt->second;
}
return std::nullopt;
}

std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
Expand Down
14 changes: 9 additions & 5 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ class PyRegion {

/// Wrapper around an MlirAsmState.
class PyAsmState {
public:
public:
PyAsmState(MlirValue value, bool useLocalScope) {
flags = mlirOpPrintingFlagsCreate();
// The OpPrintingFlags are not exposed Python side, create locally and
Expand All @@ -780,16 +780,14 @@ class PyAsmState {
state =
mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
}
~PyAsmState() {
mlirOpPrintingFlagsDestroy(flags);
}
~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
// Delete copy constructors.
PyAsmState(PyAsmState &other) = delete;
PyAsmState(const PyAsmState &other) = delete;

MlirAsmState get() { return state; }

private:
private:
MlirAsmState state;
MlirOpPrintingFlags flags;
};
Expand Down Expand Up @@ -1112,6 +1110,10 @@ class PyConcreteAttribute : public BaseTy {
/// bindings so such operation always exists).
class PyValue {
public:
// The virtual here is "load bearing" in that it enables RTTI
// for PyConcreteValue CRTP classes that support maybeDownCast.
// See PyValue::maybeDownCast.
virtual ~PyValue() = default;
PyValue(PyOperationRef parentOperation, MlirValue value)
: parentOperation(std::move(parentOperation)), value(value) {}
operator MlirValue() const { return value; }
Expand All @@ -1124,6 +1126,8 @@ class PyValue {
/// Gets a capsule wrapping the void* within the MlirValue.
pybind11::object getCapsule();

pybind11::object maybeDownCast();

/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
static PyValue createFromCapsule(pybind11::object capsule);
Expand Down
30 changes: 22 additions & 8 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
#include "IRModule.h"
#include "Pass.h"

#include <tuple>

namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
Expand Down Expand Up @@ -46,7 +44,8 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"operation_name"_a, "operation_class"_a, "replace"_a = false,
"operation_name"_a, "operation_class"_a, py::kw_only(),
"replace"_a = false,
"Testing hook for directly registering an operation");

// Aside from making the globals accessible to python, having python manage
Expand Down Expand Up @@ -82,17 +81,32 @@ PYBIND11_MODULE(_mlir, m) {
return opClass;
});
},
"dialect_class"_a, "replace"_a = false,
"dialect_class"_a, py::kw_only(), "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
[](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
replace);
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
return py::cpp_function([mlirTypeID,
replace](py::object typeCaster) -> py::object {
PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
return typeCaster;
});
},
"typeid"_a, "type_caster"_a, "replace"_a = false,
"typeid"_a, py::kw_only(), "replace"_a = false,
"Register a type caster for casting MLIR types to custom user types.");
m.def(
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
return py::cpp_function(
[mlirTypeID, replace](py::object valueCaster) -> py::object {
PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
replace);
return valueCaster;
});
},
"typeid"_a, py::kw_only(), "replace"_a = false,
"Register a value caster for casting MLIR values to custom user values.");

// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
Expand Down
15 changes: 13 additions & 2 deletions mlir/lib/Bindings/Python/PybindUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H

#include "mlir-c/Support.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/DataTypes.h"

Expand Down Expand Up @@ -228,6 +229,11 @@ class Sliceable {
return linearIndex;
}

/// Trait to check if T provides a `maybeDownCast` method.
/// Note, you need the & to detect inherited members.
template <typename T, typename... Args>
using has_maybe_downcast = decltype(&T::maybeDownCast);

/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
Expand All @@ -239,8 +245,13 @@ class Sliceable {
return {};
}

return pybind11::cast(
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
return static_cast<Derived *>(this)
->getRawElement(linearizeIndex(index))
.maybeDownCast();
else
return pybind11::cast(
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
}

/// Returns a new instance of the pseudo-container restricted to the given
Expand Down
13 changes: 12 additions & 1 deletion mlir/python/mlir/dialects/_ods_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
from typing import Sequence as _Sequence, Union as _Union
from typing import (
Sequence as _Sequence,
Type as _Type,
TypeVar as _TypeVar,
Union as _Union,
)

__all__ = [
"equally_sized_accessor",
Expand Down Expand Up @@ -123,3 +128,9 @@ def get_op_result_or_op_results(
if len(op.results) > 0
else op
)


# This is the standard way to indicate subclass/inheritance relationship
# see the typing.Type doc string.
_U = _TypeVar("_U", bound=_cext.ir.Value)
SubClassValueT = _Type[_U]
2 changes: 1 addition & 1 deletion mlir/python/mlir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster
from ._mlir_libs._mlir import register_type_caster, register_value_caster


# Convenience decorator for registering user-friendly Attribute builders.
Expand Down
Loading

0 comments on commit 7c85086

Please sign in to comment.