Skip to content

Commit

Permalink
[mlir][python] Directly implement sequence protocol on Sliceable.
Browse files Browse the repository at this point in the history
* While annoying, this is the only way to get C++ exception handling out of the happy path for normal iteration.
* Implements sq_length and sq_item for the sequence protocol (used for iteration, including list() construction).
* Implements mp_subscript for general use (i.e. foo[1] and foo[1:1]).
* For constructing a `list(op.results)`, this reduces the time from ~4-5us to ~1.5us on my machine (give or take measurement overhead) and eliminates C++ exceptions, which is a worthy goal in itself.
  * Compared to a baseline of similar construction of a three-integer list, which takes 450ns (might just be measuring function call overhead).
  * See issue discussed on the pybind side: pybind/pybind11#2842

Differential Revision: https://reviews.llvm.org/D119691
  • Loading branch information
stellaraccident committed Feb 14, 2022
1 parent e404e22 commit 429b0cf
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 31 deletions.
102 changes: 74 additions & 28 deletions mlir/lib/Bindings/Python/PybindUtils.h
Expand Up @@ -207,6 +207,8 @@ struct PySinglePartStringAccumulator {
/// constructs a new instance of the derived pseudo-container with the
/// given slice parameters (to be forwarded to the Sliceable constructor).
///
/// The getNumElements() and getElement(intptr_t) callbacks must not throw.
///
/// A derived class may additionally define:
/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
/// the python class.
Expand All @@ -215,49 +217,53 @@ class Sliceable {
protected:
using ClassTy = pybind11::class_<Derived>;

// Transforms `index` into a legal value to access the underlying sequence.
// Returns <0 on failure.
intptr_t wrapIndex(intptr_t index) {
if (index < 0)
index = length + index;
if (index < 0 || index >= length) {
throw python::SetPyError(PyExc_IndexError,
"attempt to access out of bounds");
}
if (index < 0 || index >= length)
return -1;
return index;
}

public:
explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
: startIndex(startIndex), length(length), step(step) {
assert(length >= 0 && "expected non-negative slice length");
}

/// Returns the length of the slice.
intptr_t dunderLen() const { return length; }

/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Throws if the index is out of bounds.
ElementTy dunderGetItem(intptr_t index) {
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
pybind11::object getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
PyErr_SetString(PyExc_IndexError, "index out of range");
return {};
}

// Compute the linear index given the current slice properties.
int linearIndex = index * step + startIndex;
assert(linearIndex >= 0 &&
linearIndex < static_cast<Derived *>(this)->getNumElements() &&
"linear index out of bounds, the slice is ill-formed");
return static_cast<Derived *>(this)->getElement(linearIndex);
return pybind11::cast(
static_cast<Derived *>(this)->getElement(linearIndex));
}

/// Returns a new instance of the pseudo-container restricted to the given
/// slice.
Derived dunderGetItemSlice(pybind11::slice slice) {
/// slice. Returns a nullptr object on failure.
pybind11::object getItemSlice(PyObject *slice) {
ssize_t start, stop, extraStep, sliceLength;
if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) {
throw python::SetPyError(PyExc_IndexError,
"attempt to access out of bounds");
if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
&sliceLength) != 0) {
PyErr_SetString(PyExc_IndexError, "index out of range");
return {};
}
return static_cast<Derived *>(this)->slice(startIndex + start * step,
sliceLength, step * extraStep);
return pybind11::cast(static_cast<Derived *>(this)->slice(
startIndex + start * step, sliceLength, step * extraStep));
}

public:
explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
: startIndex(startIndex), length(length), step(step) {
assert(length >= 0 && "expected non-negative slice length");
}

/// Returns a new vector (mapped to Python list) containing elements from two
Expand All @@ -267,10 +273,10 @@ class Sliceable {
std::vector<ElementTy> elements;
elements.reserve(length + other.length);
for (intptr_t i = 0; i < length; ++i) {
elements.push_back(dunderGetItem(i));
elements.push_back(static_cast<Derived *>(this)->getElement(i));
}
for (intptr_t i = 0; i < other.length; ++i) {
elements.push_back(other.dunderGetItem(i));
elements.push_back(static_cast<Derived *>(this)->getElement(i));
}
return elements;
}
Expand All @@ -279,11 +285,51 @@ class Sliceable {
static void bind(pybind11::module &m) {
auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
pybind11::module_local())
.def("__len__", &Sliceable::dunderLen)
.def("__getitem__", &Sliceable::dunderGetItem)
.def("__getitem__", &Sliceable::dunderGetItemSlice)
.def("__add__", &Sliceable::dunderAdd);
Derived::bindDerived(clazz);

// Manually implement the sequence protocol via the C API. We do this
// because it is approx 4x faster than via pybind11, largely because that
// formulation requires a C++ exception to be thrown to detect end of
// sequence.
// Since we are in a C-context, any C++ exception that happens here
// will terminate the program. There is nothing in this implementation
// that should throw in a non-terminal way, so we forgo further
// exception marshalling.
// See: https://github.com/pybind/pybind11/issues/2842
auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
"must be heap type");
heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
auto self = pybind11::cast<Derived *>(rawSelf);
return self->length;
};
// sq_item is called as part of the sequence protocol for iteration,
// list construction, etc.
heap_type->as_sequence.sq_item =
+[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
auto self = pybind11::cast<Derived *>(rawSelf);
return self->getItem(index).release().ptr();
};
// mp_subscript is used for both slices and integer lookups.
heap_type->as_mapping.mp_subscript =
+[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
auto self = pybind11::cast<Derived *>(rawSelf);
Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
if (!PyErr_Occurred()) {
// Integer indexing.
return self->getItem(index).release().ptr();
}
PyErr_Clear();

// Assume slice-based indexing.
if (PySlice_Check(rawSubscript)) {
return self->getItemSlice(rawSubscript).release().ptr();
}

PyErr_SetString(PyExc_ValueError, "expected integer or slice");
return nullptr;
};
}

/// Hook for derived classes willing to bind more methods.
Expand Down
14 changes: 11 additions & 3 deletions mlir/test/python/ir/operation.py
Expand Up @@ -14,6 +14,14 @@ def run(f):
return f


def expect_index_error(callback):
try:
_ = callback()
raise RuntimeError("Expected IndexError")
except IndexError:
pass


# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
Expand Down Expand Up @@ -418,7 +426,9 @@ def testOperationResultList():
for t in call.results.types:
print(f"Result type {t}")


# Out of range
expect_index_error(lambda: call.results[3])
expect_index_error(lambda: call.results[-4])


# CHECK-LABEL: TEST: testOperationResultListSlice
Expand Down Expand Up @@ -470,8 +480,6 @@ def testOperationResultListSlice():
print(f"Result {res.result_number}, type {res.type}")




# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
Expand Down

0 comments on commit 429b0cf

Please sign in to comment.