Skip to content

Commit

Permalink
Expose rpds.Queue.
Browse files Browse the repository at this point in the history
  • Loading branch information
Julian committed Dec 16, 2023
1 parent 1202558 commit e34cdde
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rpds-py"
version = "0.13.2"
version = "0.14.0"
edition = "2021"

[lib]
Expand Down
122 changes: 121 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use pyo3::pyclass::CompareOp;
use pyo3::types::{PyDict, PyIterator, PyTuple, PyType};
use pyo3::{exceptions::PyKeyError, types::PyMapping};
use pyo3::{prelude::*, AsPyPointer, PyTypeInfo};
use rpds::{HashTrieMap, HashTrieMapSync, HashTrieSet, HashTrieSetSync, List, ListSync};
use rpds::{
HashTrieMap, HashTrieMapSync, HashTrieSet, HashTrieSetSync, List, ListSync, Queue, QueueSync,
};

#[derive(Clone, Debug)]
struct Key {
Expand Down Expand Up @@ -618,12 +620,130 @@ impl ListIterator {
}
}

#[repr(transparent)]
#[pyclass(name = "Queue", module = "rpds", frozen, sequence)]
struct QueuePy {
inner: QueueSync<PyObject>,
}

impl From<QueueSync<PyObject>> for QueuePy {
fn from(elements: QueueSync<PyObject>) -> Self {
QueuePy { inner: elements }
}
}

impl<'source> FromPyObject<'source> for QueuePy {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let mut ret = Queue::new_sync();
for each in ob.iter()? {
ret.enqueue_mut(each?.extract()?);
}
Ok(QueuePy { inner: ret })
}
}

#[pymethods]
impl QueuePy {
#[new]
#[pyo3(signature = (*elements))]
fn init(elements: &PyTuple, py: Python<'_>) -> PyResult<Self> {
let mut ret: QueuePy;
if elements.len() == 1 {
ret = elements.get_item(0)?.extract()?;
} else {
ret = QueuePy {
inner: Queue::new_sync(),
};
if elements.len() > 1 {
for each in elements {
ret.inner.enqueue_mut(each.into_py(py));
}
}
}
Ok(ret)
}

fn __eq__(&self, other: &Self, py: Python<'_>) -> bool {
(self.inner.len() == other.inner.len())
&& self
.inner
.iter()
.zip(other.inner.iter())
.map(|(e1, e2)| PyAny::eq(e1.extract(py)?, e2))
.all(|r| r.unwrap_or(false))
}

fn __ne__(&self, other: &Self, py: Python<'_>) -> bool {
(self.inner.len() != other.inner.len())
|| self
.inner
.iter()
.zip(other.inner.iter())
.map(|(e1, e2)| PyAny::ne(e1.extract(py)?, e2))
.any(|r| r.unwrap_or(true))
}

fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<ListIterator>> {
let iter = slf
.inner
.iter()
.map(|k| k.to_owned())
.collect::<Vec<_>>()
.into_iter();
Py::new(slf.py(), ListIterator { inner: iter })
}

fn __len__(&self) -> usize {
self.inner.len()
}

fn __repr__(&self, py: Python) -> String {
let contents = self.inner.into_iter().map(|k| {
k.clone()
.into_py(py)
.call_method0(py, "__repr__")
.and_then(|r| r.extract(py))
.unwrap_or("<repr failed>".to_owned())
});
format!("Queue([{}])", contents.collect::<Vec<_>>().join(", "))
}

#[getter]
fn peek(&self) -> PyResult<PyObject> {
if let Some(peeked) = self.inner.peek() {
Ok(peeked.to_owned())
} else {
Err(PyIndexError::new_err("peeked an empty queue"))
}
}

#[getter]
fn is_empty(&self) -> bool {
self.inner.is_empty()
}

fn enqueue(&self, value: &PyAny) -> Self {
QueuePy {
inner: self.inner.enqueue(value.into()),
}
}

fn dequeue(&self) -> PyResult<QueuePy> {
if let Some(inner) = self.inner.dequeue() {
Ok(QueuePy { inner })
} else {
Err(PyIndexError::new_err("dequeued an empty queue"))
}
}
}

#[pymodule]
#[pyo3(name = "rpds")]
fn rpds_py(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<HashTrieMapPy>()?;
PyMapping::register::<HashTrieMapPy>(py)?;
m.add_class::<HashTrieSetPy>()?;
m.add_class::<ListPy>()?;
m.add_class::<QueuePy>()?;
Ok(())
}
133 changes: 133 additions & 0 deletions tests/test_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""
Modified from the pyrsistent test suite.
Pre-modification, these were MIT licensed, and are copyright:
Copyright (c) 2022 Tobias Gustafsson
Permission is hereby granted, free of charge, to any person
obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without
restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
"""
import pytest

from rpds import Queue

HASH_MSG = "Not sure Queue implements Hash, it has mutable methods"


def test_literalish_works():
assert Queue(1, 2, 3) == Queue([1, 2, 3])


def test_peek_dequeue():
pl = Queue([1, 2])
assert pl.peek == 1
assert pl.dequeue().peek == 2
assert pl.dequeue().dequeue().is_empty
with pytest.raises(IndexError):
pl.dequeue().dequeue().dequeue()


def test_instantiate_large_list():
assert Queue(range(1000)).peek == 0


def test_iteration():
assert list(Queue()) == []
assert list(Queue([1, 2, 3])) == [1, 2, 3]


def test_enqueue():
assert Queue([1, 2, 3]).enqueue(4) == Queue([1, 2, 3, 4])


def test_enqueue_empty_list():
assert Queue().enqueue(0) == Queue([0])


def test_truthiness():
assert Queue([1])
assert not Queue()


def test_len():
assert len(Queue([1, 2, 3])) == 3
assert len(Queue()) == 0


def test_peek_illegal_on_empty_list():
with pytest.raises(IndexError):
Queue().peek


def test_inequality():
assert Queue([1, 2]) != Queue([1, 3])
assert Queue([1, 2]) != Queue([1, 2, 3])
assert Queue() != Queue([1, 2, 3])


def test_repr():
assert str(Queue()) == "Queue([])"
assert str(Queue([1, 2, 3])) in "Queue([1, 2, 3])"


@pytest.mark.xfail(reason=HASH_MSG)
def test_hashing():
assert hash(Queue([1, 2])) == hash(Queue([1, 2]))
assert hash(Queue([1, 2])) != hash(Queue([2, 1]))


def test_sequence():
m = Queue("asdf")
assert m == Queue(["a", "s", "d", "f"])


# Non-pyrsistent-test-suite tests


def test_dequeue():
assert Queue([1, 2, 3]).dequeue() == Queue([2, 3])


def test_dequeue_empty():
"""
rpds itself returns an Option<Queue> here but we try IndexError instead.
"""
with pytest.raises(IndexError):
Queue([]).dequeue()


def test_more_eq():
o = object()

assert Queue([o, o]) == Queue([o, o])
assert Queue([o]) == Queue([o])
assert Queue() == Queue([])
assert not (Queue([1, 2]) == Queue([1, 3]))
assert not (Queue([o]) == Queue([o, o]))
assert not (Queue([]) == Queue([o]))

assert Queue([1, 2]) != Queue([1, 3])
assert Queue([o]) != Queue([o, o])
assert Queue([]) != Queue([o])
assert not (Queue([o, o]) != Queue([o, o]))
assert not (Queue([o]) != Queue([o]))
assert not (Queue() != Queue([]))

0 comments on commit e34cdde

Please sign in to comment.