diff --git a/Cargo.lock b/Cargo.lock index 01eb683..048b868 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,7 +195,7 @@ dependencies = [ [[package]] name = "rpds-py" -version = "0.13.2" +version = "0.14.0" dependencies = [ "archery", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index 9930d0a..bc8dcd2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rpds-py" -version = "0.13.2" +version = "0.14.0" edition = "2021" [lib] diff --git a/src/lib.rs b/src/lib.rs index d5b3a91..97955ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,9 @@ use pyo3::pyclass::CompareOp; use pyo3::types::{PyDict, PyIterator, PyTuple, PyType}; use pyo3::{exceptions::PyKeyError, types::PyMapping}; use pyo3::{prelude::*, AsPyPointer, PyTypeInfo}; -use rpds::{HashTrieMap, HashTrieMapSync, HashTrieSet, HashTrieSetSync, List, ListSync}; +use rpds::{ + HashTrieMap, HashTrieMapSync, HashTrieSet, HashTrieSetSync, List, ListSync, Queue, QueueSync, +}; #[derive(Clone, Debug)] struct Key { @@ -618,6 +620,123 @@ impl ListIterator { } } +#[repr(transparent)] +#[pyclass(name = "Queue", module = "rpds", frozen, sequence)] +struct QueuePy { + inner: QueueSync, +} + +impl From> for QueuePy { + fn from(elements: QueueSync) -> Self { + QueuePy { inner: elements } + } +} + +impl<'source> FromPyObject<'source> for QueuePy { + fn extract(ob: &'source PyAny) -> PyResult { + let mut ret = Queue::new_sync(); + for each in ob.iter()? { + ret.enqueue_mut(each?.extract()?); + } + Ok(QueuePy { inner: ret }) + } +} + +#[pymethods] +impl QueuePy { + #[new] + #[pyo3(signature = (*elements))] + fn init(elements: &PyTuple, py: Python<'_>) -> PyResult { + let mut ret: QueuePy; + if elements.len() == 1 { + ret = elements.get_item(0)?.extract()?; + } else { + ret = QueuePy { + inner: Queue::new_sync(), + }; + if elements.len() > 1 { + for each in elements { + ret.inner.enqueue_mut(each.into_py(py)); + } + } + } + Ok(ret) + } + + fn __eq__(&self, other: &Self, py: Python<'_>) -> bool { + (self.inner.len() == other.inner.len()) + && self + .inner + .iter() + .zip(other.inner.iter()) + .map(|(e1, e2)| PyAny::eq(e1.extract(py)?, e2)) + .all(|r| r.unwrap_or(false)) + } + + fn __ne__(&self, other: &Self, py: Python<'_>) -> bool { + (self.inner.len() != other.inner.len()) + || self + .inner + .iter() + .zip(other.inner.iter()) + .map(|(e1, e2)| PyAny::ne(e1.extract(py)?, e2)) + .any(|r| r.unwrap_or(true)) + } + + fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let iter = slf + .inner + .iter() + .map(|k| k.to_owned()) + .collect::>() + .into_iter(); + Py::new(slf.py(), ListIterator { inner: iter }) + } + + fn __len__(&self) -> usize { + self.inner.len() + } + + fn __repr__(&self, py: Python) -> String { + let contents = self.inner.into_iter().map(|k| { + k.clone() + .into_py(py) + .call_method0(py, "__repr__") + .and_then(|r| r.extract(py)) + .unwrap_or("".to_owned()) + }); + format!("Queue([{}])", contents.collect::>().join(", ")) + } + + #[getter] + fn peek(&self) -> PyResult { + if let Some(peeked) = self.inner.peek() { + Ok(peeked.to_owned()) + } else { + Err(PyIndexError::new_err("peeked an empty queue")) + } + } + + #[getter] + fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + fn enqueue(&self, value: &PyAny) -> Self { + QueuePy { + inner: self.inner.enqueue(value.into()), + } + } + + fn dequeue(&self) -> PyResult { + if let Some(inner) = self.inner.dequeue() { + Ok(QueuePy { inner }) + } else { + Err(PyIndexError::new_err("dequeued an empty queue")) + } + } +} + #[pymodule] #[pyo3(name = "rpds")] fn rpds_py(py: Python, m: &PyModule) -> PyResult<()> { @@ -625,5 +744,6 @@ fn rpds_py(py: Python, m: &PyModule) -> PyResult<()> { PyMapping::register::(py)?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/tests/test_queue.py b/tests/test_queue.py new file mode 100644 index 0000000..0ba8f81 --- /dev/null +++ b/tests/test_queue.py @@ -0,0 +1,133 @@ +""" +Modified from the pyrsistent test suite. + +Pre-modification, these were MIT licensed, and are copyright: + + Copyright (c) 2022 Tobias Gustafsson + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation + files (the "Software"), to deal in the Software without + restriction, including without limitation the rights to use, + copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following + conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. +""" +import pytest + +from rpds import Queue + +HASH_MSG = "Not sure Queue implements Hash, it has mutable methods" + + +def test_literalish_works(): + assert Queue(1, 2, 3) == Queue([1, 2, 3]) + + +def test_peek_dequeue(): + pl = Queue([1, 2]) + assert pl.peek == 1 + assert pl.dequeue().peek == 2 + assert pl.dequeue().dequeue().is_empty + with pytest.raises(IndexError): + pl.dequeue().dequeue().dequeue() + + +def test_instantiate_large_list(): + assert Queue(range(1000)).peek == 0 + + +def test_iteration(): + assert list(Queue()) == [] + assert list(Queue([1, 2, 3])) == [1, 2, 3] + + +def test_enqueue(): + assert Queue([1, 2, 3]).enqueue(4) == Queue([1, 2, 3, 4]) + + +def test_enqueue_empty_list(): + assert Queue().enqueue(0) == Queue([0]) + + +def test_truthiness(): + assert Queue([1]) + assert not Queue() + + +def test_len(): + assert len(Queue([1, 2, 3])) == 3 + assert len(Queue()) == 0 + + +def test_peek_illegal_on_empty_list(): + with pytest.raises(IndexError): + Queue().peek + + +def test_inequality(): + assert Queue([1, 2]) != Queue([1, 3]) + assert Queue([1, 2]) != Queue([1, 2, 3]) + assert Queue() != Queue([1, 2, 3]) + + +def test_repr(): + assert str(Queue()) == "Queue([])" + assert str(Queue([1, 2, 3])) in "Queue([1, 2, 3])" + + +@pytest.mark.xfail(reason=HASH_MSG) +def test_hashing(): + assert hash(Queue([1, 2])) == hash(Queue([1, 2])) + assert hash(Queue([1, 2])) != hash(Queue([2, 1])) + + +def test_sequence(): + m = Queue("asdf") + assert m == Queue(["a", "s", "d", "f"]) + + +# Non-pyrsistent-test-suite tests + + +def test_dequeue(): + assert Queue([1, 2, 3]).dequeue() == Queue([2, 3]) + + +def test_dequeue_empty(): + """ + rpds itself returns an Option here but we try IndexError instead. + """ + with pytest.raises(IndexError): + Queue([]).dequeue() + + +def test_more_eq(): + o = object() + + assert Queue([o, o]) == Queue([o, o]) + assert Queue([o]) == Queue([o]) + assert Queue() == Queue([]) + assert not (Queue([1, 2]) == Queue([1, 3])) + assert not (Queue([o]) == Queue([o, o])) + assert not (Queue([]) == Queue([o])) + + assert Queue([1, 2]) != Queue([1, 3]) + assert Queue([o]) != Queue([o, o]) + assert Queue([]) != Queue([o]) + assert not (Queue([o, o]) != Queue([o, o])) + assert not (Queue([o]) != Queue([o])) + assert not (Queue() != Queue([]))