some cleanup, alter bounded::channel_from_iter API again (slightly)
Oliver Giersch committed Mar 12, 2024
1 parent af1b996 commit 744913d
Showing 4 changed files with 56 additions and 30 deletions.
32 changes: 20 additions & 12 deletions src/bounded.rs
@@ -22,18 +22,14 @@ pub const fn channel<T>(capacity: usize) -> Channel<T> {
Channel { queue: BoundedQueue::new(capacity) }
}

/// Returns a new bounded channel with pre-queued elements.
/// Returns a new bounded channel with pre-queued items.
///
/// The initial capacity will be the difference between `capacity` and the
/// number of elements returned by the [`Iterator`].
/// The iterator may return more than `capacity` elements, but the channel's
/// capacity will never exceed the given `capacity`.
///
/// # Panics
///
/// Panics, if `capacity` is zero.
pub fn channel_from_iter<T>(capacity: usize, iter: impl IntoIterator<Item = T>) -> Channel<T> {
Channel::from_iter(capacity, iter)
/// The channel's (total) capacity will be the maximum of `min_capacity` and
/// the number of items returned by `iter`.
/// Its initial available capacity will be the difference between its total
/// capacity and the number of pre-queued items.
pub fn channel_from_iter<T>(min_capacity: usize, iter: impl IntoIterator<Item = T>) -> Channel<T> {
Channel::from_iter(min_capacity, iter)
}

/// An unsynchronized (`!Sync`), asynchronous and bounded channel.
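
To make the revised capacity semantics concrete, here is a short usage sketch. The paths and the `capacity()`/`recv()` calls mirror this commit's own tests; that `capacity()` reports the remaining (rather than the total) capacity is an assumption based on the `from_iter_less` test further down.

    future::block_on(async {
        // min_capacity = 2 but 3 items are pre-queued: total capacity is
        // max(2, 3) = 3, so the initially available capacity is 3 - 3 = 0.
        let full = bounded::channel_from_iter(2, [0, 1, 2]);
        assert_eq!(full.capacity(), 0);

        // min_capacity = 5 with 3 pre-queued items: total capacity is 5,
        // 5 - 3 = 2 slots remain available and the pre-queued items can be
        // received immediately.
        let chan = bounded::channel_from_iter(5, [0, 1, 2]);
        assert_eq!(chan.recv().await, Some(0));
    });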
@@ -955,6 +951,12 @@ mod tests {

use crate::queue::RecvFuture;

#[test]
#[should_panic]
fn channel_panic() {
let _ = super::channel::<i32>(0);
}

#[test]
fn recv_split() {
future::block_on(async {
@@ -1283,7 +1285,13 @@
}

#[test]
fn from_iter() {
fn from_iter_less() {
let chan = super::channel_from_iter(0, &[0, 1, 2, 3]);
assert_eq!(chan.capacity(), 0);
}

#[test]
fn from_iter_more() {
future::block_on(async {
let chan = super::Channel::from_iter(5, [0, 1, 2, 3]);
assert_eq!(chan.recv().await, Some(0));
8 changes: 6 additions & 2 deletions src/queue.rs
@@ -113,10 +113,10 @@ impl<T> UnsyncQueue<T, Bounded> {
}

pub(crate) fn from_iter(capacity: usize, iter: impl IntoIterator<Item = T>) -> Self {
assert_capacity(capacity);
let queue = VecDeque::from_iter(iter);
let capacity = cmp::max(queue.len(), capacity);
let initial_capacity = capacity.saturating_sub(queue.len());
let initial_capacity = capacity - queue.len();

Self(UnsafeCell::new(Queue::new(
queue,
Bounded {
@@ -307,6 +307,7 @@ impl<T, B> Queue<T, B>
where
Self: MaybeBoundedQueue<Item = T>,
{
#[cold]
pub(crate) fn set_counted(&mut self) {
self.reset();
self.waker = None;
@@ -338,6 +339,7 @@ impl<T> MaybeBoundedQueue for Queue<T, Unbounded> {

fn reset(&mut self) {}

#[cold]
fn close<const COUNTED: bool>(&mut self) {
self.mask.close::<COUNTED>();
}
@@ -357,13 +359,15 @@ impl<T> MaybeBoundedQueue for Queue<T, Unbounded> {
impl<T> MaybeBoundedQueue for Queue<T, Bounded> {
type Item = T;

#[cold]
fn reset(&mut self) {
// this can never underflow, because `permits` is never increased above
// the specified `max_capacity`
let diff = self.extra.max_capacity - self.extra.semaphore.available_permits();
self.extra.semaphore.add_permits(diff);
}

#[cold]
fn close<const COUNTED: bool>(&mut self) {
// must also close semaphore in order to notify all waiting senders
self.mask.close::<COUNTED>();
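
A quick, self-contained check of the invariant that lets `from_iter` above replace the earlier `saturating_sub` with a plain subtraction; the values are arbitrary stand-ins for the queue state.

    use std::{cmp, collections::VecDeque};

    // Mirrors the computation in `UnsyncQueue::from_iter`: after taking the
    // maximum, `capacity >= queue.len()` always holds, so the subtraction
    // can no longer underflow.
    let queue: VecDeque<i32> = VecDeque::from_iter([0, 1, 2, 3]);
    let requested = 2;
    let capacity = cmp::max(queue.len(), requested); // 4
    let initial_capacity = capacity - queue.len(); // 0, cannot underflow
    assert_eq!(initial_capacity, 0);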
43 changes: 27 additions & 16 deletions src/semaphore.rs
@@ -112,18 +112,7 @@ impl Semaphore {
/// Returns a correctly initialized [`Acquire`] future instance for
/// acquiring `wants` permits.
fn build_acquire(&self, wants: usize) -> Acquire<'_> {
Acquire {
shared: &self.shared,
waiter: Waiter {
wants,
waker: LateInitWaker::new(),
state: Cell::new(WaiterState::Inert),
permits: Cell::new(0),
next: Cell::new(ptr::null()),
prev: Cell::new(ptr::null()),
_marker: PhantomPinned,
},
}
Acquire { shared: &self.shared, waiter: Waiter::new(wants) }
}
}

@@ -232,6 +221,7 @@ pub struct Acquire<'a> {
impl<'a> Future for Acquire<'a> {
type Output = Result<Permit<'a>, AcquireError>;

#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// SAFETY: The `Acquire` future can not be moved before being dropped
let waiter = unsafe { Pin::map_unchecked(self.as_ref(), |acquire| &acquire.waiter) };
@@ -292,6 +282,7 @@ struct Shared {

impl Shared {
/// Closes the semaphore and notifies all remaining waiters.
#[cold]
fn close(&mut self) -> usize {
// SAFETY: non-live waiters do not exist in queue, no aliased access
// possible
@@ -415,10 +406,10 @@ impl Shared {
waiter.waker.set(cx.waker().clone());
// SAFETY: All waiters remain valid while they are enqueued.
//
// Each `Acquire` future contains/owns a `Waiter` and may either live on
// the stack or the heap.
// Each future must be pinned before it can be polled and therefore both
// the future and the waiter will remain in-place for their entire
// Each `Acquire` future contains (owns) a `Waiter` and may either live
// on the stack or the heap.
// Each future *must* be pinned before it can be polled and therefore
// both the future and the waiter will remain in-place for their entire
// lifetime.
// When the future/waiter are cancelled or dropped, they will dequeue
// themselves to ensure no iteration over freed data is possible.
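
The safety argument above rests on Rust's pinning contract. As a minimal, generic illustration (a plain `Future` bound stands in for the private `Acquire` type), `poll` is only reachable through `Pin<&mut _>`, which is what keeps an enqueued waiter's address stable until it is dropped.

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // Any future, like `Acquire`, can only be polled once it is pinned, so the
    // memory holding it (and any intrusive `Waiter` it owns) cannot be moved
    // or repurposed before it is dropped.
    fn poll_once<F: Future>(fut: Pin<&mut F>, cx: &mut Context<'_>) -> Poll<F::Output> {
        fut.poll(cx)
    }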
@@ -455,7 +446,10 @@ impl WaiterQueue {
/// # Safety
///
/// All pointers must reference valid, live and non-aliased `Waiter`s.
#[cold]
unsafe fn len(&self) -> usize {
// this is only used in the [`Debug`] implementation, so counting each
// waiter one by one here is not performance-critical
let mut curr = self.head;
let mut waiting = 0;
while !curr.is_null() {
@@ -492,6 +486,7 @@ impl WaiterQueue {
/// # Safety
///
/// All pointers must reference valid, live and non-aliased `Waiter`s.
#[cold]
unsafe fn try_remove(&mut self, waiter: &Waiter) {
let prev = waiter.prev.get();
if prev.is_null() {
@@ -516,6 +511,7 @@
///
/// All pointers must reference valid, live and non-aliased `Waiter`s and
/// `head` must be the current queue head.
#[inline]
unsafe fn pop_front(&mut self, head: &Waiter) {
self.head = head.next.get();
if self.head.is_null() {
@@ -525,6 +521,7 @@
}
}

#[cold]
unsafe fn wake_all(&mut self) -> usize {
let mut curr = self.head;
let mut woken = 0;
@@ -569,6 +566,20 @@ struct Waiter {
_marker: PhantomPinned,
}

impl Waiter {
const fn new(wants: usize) -> Self {
Self {
wants,
waker: LateInitWaker::new(),
state: Cell::new(WaiterState::Inert),
permits: Cell::new(0),
next: Cell::new(ptr::null()),
prev: Cell::new(ptr::null()),
_marker: PhantomPinned,
}
}
}

/// The current state of a [`Waiter`].
#[derive(Clone, Copy)]
enum WaiterState {
3 changes: 3 additions & 0 deletions src/unbounded.rs
@@ -89,6 +89,7 @@ impl<T> UnboundedChannel<T> {
}

/// Closes the channel, ensuring that all subsequent sends will fail.
#[cold]
pub fn close(&self) {
self.queue.close::<UNCOUNTED>();
}
@@ -301,6 +302,7 @@ pub struct UnboundedReceiver<T> {

impl<T> UnboundedReceiver<T> {
/// Closes the channel, ensuring that all subsequent sends will fail.
#[cold]
pub fn close(&mut self) {
self.queue.close::<COUNTED>();
}
@@ -371,6 +373,7 @@ pub struct UnboundedReceiverRef<'a, T> {

impl<T> UnboundedReceiverRef<'_, T> {
/// Closes the channel, ensuring that all subsequent sends will fail.
#[cold]
pub fn close(&mut self) {
self.queue.close::<COUNTED>();
}
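
A small sketch of the documented close-then-send contract. The free function `unbounded::channel()` is modeled on the bounded module above, and the fallible `send` call is an assumption about the sender API; neither appears in this diff.

    // hypothetical usage: `channel()` and a `Result`-returning `send` are
    // assumptions based on the doc comment "all subsequent sends will fail"
    let chan = unbounded::channel::<i32>();
    chan.close();
    assert!(chan.send(1).is_err());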
