From 293a6c789ad54368ee561034be8be21cd4bcd9ee Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Thu, 21 Apr 2022 07:35:28 -0400 Subject: [PATCH] Add support for slice inputs to custom sequence return types This commit adds support for dealing with slice inputs to custom sequence return types such as NodeIndices. Previously, if a slice was requested from a custom return type it would have raised a TypeError. With this commit now it will return a new container object with a copy of the data from the requested slice. Fixes #590 --- ...slices-for-sequences-a3b31c70f5d896b4.yaml | 21 ++++++++++ src/iterators.rs | 41 +++++++++++++++---- tests/test_custom_return_types.py | 39 +++++++++++++++++- 3 files changed, 91 insertions(+), 10 deletions(-) create mode 100644 releasenotes/notes/slices-for-sequences-a3b31c70f5d896b4.yaml diff --git a/releasenotes/notes/slices-for-sequences-a3b31c70f5d896b4.yaml b/releasenotes/notes/slices-for-sequences-a3b31c70f5d896b4.yaml new file mode 100644 index 000000000..13f113e98 --- /dev/null +++ b/releasenotes/notes/slices-for-sequences-a3b31c70f5d896b4.yaml @@ -0,0 +1,21 @@ +--- +fixes: + - | + The custom sequence return classes: + + * :class:`~.BFSSSuccessors` + * :class:`~.NodeIndices` + * :class:`~.EdgeList` + * :class:`~WeightedEdgeList` + * :class:`~EdgeIndices` + * :class:`~Chains` + + now correctly handle slice inputs to ``__getitem__``. Previously if you + tried to access a slice from one of these objects it would raise a + ``TypeError. For example, if you had a :class:`~.NodeIndices` object named + ``nodes`` containing ``[0, 1, 3, 4, 5]`` if you did something like:: + + nodes[0:3] + + it would return a new :class:`~.NodeIndices` object containing ``[0, 1, 3]`` + Fixed `#590 `__ diff --git a/src/iterators.rs b/src/iterators.rs index 1486f1ea9..e383ac2ba 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -48,6 +48,7 @@ use pyo3::class::iter::IterNextOutput; use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError}; use pyo3::gc::PyVisit; use pyo3::prelude::*; +use pyo3::types::PySlice; use pyo3::PyTraverseError; macro_rules! last_type { @@ -406,6 +407,12 @@ trait PyGCProtocol { fn __clear__(&mut self) {} } +#[derive(FromPyObject)] +enum SliceOrInt<'a> { + Slice(&'a PySlice), + Int(isize), +} + macro_rules! custom_vec_iter_impl { ($name:ident, $data:ident, $T:ty, $doc:literal) => { #[doc = $doc] @@ -471,14 +478,32 @@ macro_rules! custom_vec_iter_impl { Ok(self.$data.len()) } - fn __getitem__(&self, idx: isize) -> PyResult<$T> { - if idx.abs() >= self.$data.len().try_into().unwrap() { - Err(PyIndexError::new_err(format!("Invalid index, {}", idx))) - } else if idx < 0 { - let len = self.$data.len(); - Ok(self.$data[len - idx.abs() as usize].clone()) - } else { - Ok(self.$data[idx as usize].clone()) + fn __getitem__(&self, py: Python, idx: SliceOrInt) -> PyResult { + match idx { + SliceOrInt::Slice(slc) => { + let len: i64 = self.$data.len().try_into().unwrap(); + let indices = slc.indices(len)?; + let start: usize = indices.start.try_into().unwrap(); + let stop: usize = indices.stop.try_into().unwrap(); + let step: usize = indices.step.try_into().unwrap(); + let return_vec = $name { + $data: (start..stop) + .step_by(step) + .map(|i| self.$data[i].clone()) + .collect(), + }; + Ok(return_vec.into_py(py)) + } + SliceOrInt::Int(idx) => { + if idx.abs() >= self.$data.len().try_into().unwrap() { + Err(PyIndexError::new_err(format!("Invalid index, {}", idx))) + } else if idx < 0 { + let len = self.$data.len(); + Ok(self.$data[len - idx.abs() as usize].clone().into_py(py)) + } else { + Ok(self.$data[idx as usize].clone().into_py(py)) + } + } } } diff --git a/tests/test_custom_return_types.py b/tests/test_custom_return_types.py index ed52a42e2..441a8110b 100644 --- a/tests/test_custom_return_types.py +++ b/tests/test_custom_return_types.py @@ -20,8 +20,8 @@ class TestBFSSuccessorsComparisons(unittest.TestCase): def setUp(self): self.dag = retworkx.PyDAG() - node_a = self.dag.add_node("a") - self.dag.add_child(node_a, "b", "Edgy") + self.node_a = self.dag.add_node("a") + self.node_b = self.dag.add_child(self.node_a, "b", "Edgy") def test__eq__match(self): self.assertTrue(retworkx.bfs_successors(self.dag, 0) == [("a", ["b"])]) @@ -87,6 +87,13 @@ def test_hash_invalid_type(self): with self.assertRaises(TypeError): hash(res) + def test_slices(self): + self.dag.add_child(self.node_a, "c", "New edge") + self.dag.add_child(self.node_b, "d", "New edge to d") + successors = retworkx.bfs_successors(self.dag, 0) + slice_return = successors[0:3:2] + self.assertEqual([("a", ["c", "b"])], slice_return) + class TestNodeIndicesComparisons(unittest.TestCase): def setUp(self): @@ -146,6 +153,14 @@ def test_hash(self): # Assert hash is stable self.assertEqual(hash_res, hash(res)) + def test_slices(self): + self.dag.add_node("new") + self.dag.add_node("fun") + nodes = self.dag.node_indices() + slice_return = nodes[0:3:2] + self.assertEqual([0, 2], slice_return) + self.assertEqual(nodes[0:-1], [0, 1, 2]) + class TestNodesCountMapping(unittest.TestCase): def setUp(self): @@ -308,6 +323,12 @@ def test_hash(self): # Assert hash is stable self.assertEqual(hash_res, hash(res)) + def test_slices(self): + self.dag.add_edge(0, 1, None) + edges = self.dag.edge_indices() + slice_return = edges[0:-1] + self.assertEqual([0, 1], slice_return) + class TestEdgeListComparisons(unittest.TestCase): def setUp(self): @@ -365,6 +386,13 @@ def test_hash(self): # Assert hash is stable self.assertEqual(hash_res, hash(res)) + def test_slice(self): + self.dag.add_edge(0, 1, None) + self.dag.add_edge(0, 1, None) + edges = self.dag.edge_list() + slice_return = edges[0:3:2] + self.assertEqual([(0, 1), (0, 1)], slice_return) + class TestWeightedEdgeListComparisons(unittest.TestCase): def setUp(self): @@ -428,6 +456,13 @@ def test_hash_invalid_type(self): with self.assertRaises(TypeError): hash(res) + def test_slice(self): + self.dag.add_edge(0, 1, None) + self.dag.add_edge(0, 1, None) + edges = self.dag.weighted_edge_list() + slice_return = edges[0:3:2] + self.assertEqual([(0, 1, "Edgy"), (0, 1, None)], slice_return) + class TestPathMapping(unittest.TestCase): def setUp(self):