Skip to content

Commit

Permalink
Add support for slice inputs to custom sequence return types
Browse files Browse the repository at this point in the history
This commit adds support for dealing with slice inputs to custom
sequence return types such as NodeIndices. Previously, if a slice
was requested from a custom return type it would have raised a
TypeError. With this commit now it will return a new container object
with a copy of the data from the requested slice.

Fixes Qiskit#590
  • Loading branch information
mtreinish committed Apr 21, 2022
1 parent add8e5b commit 293a6c7
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 10 deletions.
21 changes: 21 additions & 0 deletions releasenotes/notes/slices-for-sequences-a3b31c70f5d896b4.yaml
@@ -0,0 +1,21 @@
---
fixes:
- |
The custom sequence return classes:
* :class:`~.BFSSSuccessors`
* :class:`~.NodeIndices`
* :class:`~.EdgeList`
* :class:`~WeightedEdgeList`
* :class:`~EdgeIndices`
* :class:`~Chains`
now correctly handle slice inputs to ``__getitem__``. Previously if you
tried to access a slice from one of these objects it would raise a
``TypeError. For example, if you had a :class:`~.NodeIndices` object named
``nodes`` containing ``[0, 1, 3, 4, 5]`` if you did something like::
nodes[0:3]
it would return a new :class:`~.NodeIndices` object containing ``[0, 1, 3]``
Fixed `#590 <https://github.com/Qiskit/retworkx/issues/590>`__
41 changes: 33 additions & 8 deletions src/iterators.rs
Expand Up @@ -48,6 +48,7 @@ use pyo3::class::iter::IterNextOutput;
use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError};
use pyo3::gc::PyVisit;
use pyo3::prelude::*;
use pyo3::types::PySlice;
use pyo3::PyTraverseError;

macro_rules! last_type {
Expand Down Expand Up @@ -406,6 +407,12 @@ trait PyGCProtocol {
fn __clear__(&mut self) {}
}

#[derive(FromPyObject)]
enum SliceOrInt<'a> {
Slice(&'a PySlice),
Int(isize),
}

macro_rules! custom_vec_iter_impl {
($name:ident, $data:ident, $T:ty, $doc:literal) => {
#[doc = $doc]
Expand Down Expand Up @@ -471,14 +478,32 @@ macro_rules! custom_vec_iter_impl {
Ok(self.$data.len())
}

fn __getitem__(&self, idx: isize) -> PyResult<$T> {
if idx.abs() >= self.$data.len().try_into().unwrap() {
Err(PyIndexError::new_err(format!("Invalid index, {}", idx)))
} else if idx < 0 {
let len = self.$data.len();
Ok(self.$data[len - idx.abs() as usize].clone())
} else {
Ok(self.$data[idx as usize].clone())
fn __getitem__(&self, py: Python, idx: SliceOrInt) -> PyResult<PyObject> {
match idx {
SliceOrInt::Slice(slc) => {
let len: i64 = self.$data.len().try_into().unwrap();
let indices = slc.indices(len)?;
let start: usize = indices.start.try_into().unwrap();
let stop: usize = indices.stop.try_into().unwrap();
let step: usize = indices.step.try_into().unwrap();
let return_vec = $name {
$data: (start..stop)
.step_by(step)
.map(|i| self.$data[i].clone())
.collect(),
};
Ok(return_vec.into_py(py))
}
SliceOrInt::Int(idx) => {
if idx.abs() >= self.$data.len().try_into().unwrap() {
Err(PyIndexError::new_err(format!("Invalid index, {}", idx)))
} else if idx < 0 {
let len = self.$data.len();
Ok(self.$data[len - idx.abs() as usize].clone().into_py(py))
} else {
Ok(self.$data[idx as usize].clone().into_py(py))
}
}
}
}

Expand Down
39 changes: 37 additions & 2 deletions tests/test_custom_return_types.py
Expand Up @@ -20,8 +20,8 @@
class TestBFSSuccessorsComparisons(unittest.TestCase):
def setUp(self):
self.dag = retworkx.PyDAG()
node_a = self.dag.add_node("a")
self.dag.add_child(node_a, "b", "Edgy")
self.node_a = self.dag.add_node("a")
self.node_b = self.dag.add_child(self.node_a, "b", "Edgy")

def test__eq__match(self):
self.assertTrue(retworkx.bfs_successors(self.dag, 0) == [("a", ["b"])])
Expand Down Expand Up @@ -87,6 +87,13 @@ def test_hash_invalid_type(self):
with self.assertRaises(TypeError):
hash(res)

def test_slices(self):
self.dag.add_child(self.node_a, "c", "New edge")
self.dag.add_child(self.node_b, "d", "New edge to d")
successors = retworkx.bfs_successors(self.dag, 0)
slice_return = successors[0:3:2]
self.assertEqual([("a", ["c", "b"])], slice_return)


class TestNodeIndicesComparisons(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -146,6 +153,14 @@ def test_hash(self):
# Assert hash is stable
self.assertEqual(hash_res, hash(res))

def test_slices(self):
self.dag.add_node("new")
self.dag.add_node("fun")
nodes = self.dag.node_indices()
slice_return = nodes[0:3:2]
self.assertEqual([0, 2], slice_return)
self.assertEqual(nodes[0:-1], [0, 1, 2])


class TestNodesCountMapping(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -308,6 +323,12 @@ def test_hash(self):
# Assert hash is stable
self.assertEqual(hash_res, hash(res))

def test_slices(self):
self.dag.add_edge(0, 1, None)
edges = self.dag.edge_indices()
slice_return = edges[0:-1]
self.assertEqual([0, 1], slice_return)


class TestEdgeListComparisons(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -365,6 +386,13 @@ def test_hash(self):
# Assert hash is stable
self.assertEqual(hash_res, hash(res))

def test_slice(self):
self.dag.add_edge(0, 1, None)
self.dag.add_edge(0, 1, None)
edges = self.dag.edge_list()
slice_return = edges[0:3:2]
self.assertEqual([(0, 1), (0, 1)], slice_return)


class TestWeightedEdgeListComparisons(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -428,6 +456,13 @@ def test_hash_invalid_type(self):
with self.assertRaises(TypeError):
hash(res)

def test_slice(self):
self.dag.add_edge(0, 1, None)
self.dag.add_edge(0, 1, None)
edges = self.dag.weighted_edge_list()
slice_return = edges[0:3:2]
self.assertEqual([(0, 1, "Edgy"), (0, 1, None)], slice_return)


class TestPathMapping(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 293a6c7

Please sign in to comment.