diff --git a/src/iterators.rs b/src/iterators.rs index 1486f1ea96..785a3e62d7 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -49,6 +49,7 @@ use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError}; use pyo3::gc::PyVisit; use pyo3::prelude::*; use pyo3::PyTraverseError; +use pyo3::types::PySlice; macro_rules! last_type { ($a:ident,) => { $a }; @@ -406,8 +407,15 @@ trait PyGCProtocol { fn __clear__(&mut self) {} } +#[derive(FromPyObject)] +enum SliceOrInt<'a> { + Slice(&'a PySlice), + Int(isize), +} + macro_rules! custom_vec_iter_impl { ($name:ident, $data:ident, $T:ty, $doc:literal) => { + #[doc = $doc] #[pyclass(module = "retworkx")] #[derive(Clone)] @@ -471,14 +479,29 @@ macro_rules! custom_vec_iter_impl { Ok(self.$data.len()) } - fn __getitem__(&self, idx: isize) -> PyResult<$T> { - if idx.abs() >= self.$data.len().try_into().unwrap() { - Err(PyIndexError::new_err(format!("Invalid index, {}", idx))) - } else if idx < 0 { - let len = self.$data.len(); - Ok(self.$data[len - idx.abs() as usize].clone()) - } else { - Ok(self.$data[idx as usize].clone()) + fn __getitem__(&self, py: Python, idx: SliceOrInt) -> PyResult { + match idx { + SliceOrInt::Slice(slc) => { + let len: i64 = self.$data.len().try_into().unwrap(); + let indices = slc.indices(len)?; + let start: usize = indices.start.try_into().unwrap(); + let stop: usize = indices.stop.try_into().unwrap(); + let step: usize = indices.step.try_into().unwrap(); + let return_vec = $name { + $data: (start..stop).step_by(step).map(|i| self.$data[i].clone()).collect(), + }; + Ok(return_vec.into_py(py)) + }, + SliceOrInt::Int(idx) => { + if idx.abs() >= self.$data.len().try_into().unwrap() { + Err(PyIndexError::new_err(format!("Invalid index, {}", idx))) + } else if idx < 0 { + let len = self.$data.len(); + Ok(self.$data[len - idx.abs() as usize].clone().into_py(py)) + } else { + Ok(self.$data[idx as usize].clone().into_py(py)) + } + } } }