diff --git a/tests/test_utils.rs b/tests/test_utils.rs index 8194590..3c4ef4b 100644 --- a/tests/test_utils.rs +++ b/tests/test_utils.rs @@ -102,35 +102,47 @@ pub struct TestIO { } #[derive(Default)] -pub struct CloseableCursor { - data: RwLock>, - cursor: RwLock, - waker: RwLock>, - closed: RwLock, +struct CloseableCursorInner { + data: Vec, + cursor: usize, + waker: Option, + closed: bool, } +#[derive(Default)] +pub struct CloseableCursor(RwLock); + impl CloseableCursor { - fn len(&self) -> usize { - self.data.read().unwrap().len() + pub fn len(&self) -> usize { + self.0.read().unwrap().data.len() + } + + pub fn cursor(&self) -> usize { + self.0.read().unwrap().cursor } - fn cursor(&self) -> usize { - *self.cursor.read().unwrap() + pub fn is_empty(&self) -> bool { + self.len() == 0 } - fn current(&self) -> bool { - self.len() == self.cursor() + pub fn current(&self) -> bool { + let inner = self.0.read().unwrap(); + inner.data.len() == inner.cursor } - fn close(&self) { - *self.closed.write().unwrap() = true; + pub fn close(&self) { + let mut inner = self.0.write().unwrap(); + inner.closed = true; + if let Some(waker) = inner.waker.take() { + waker.wake(); + } } } impl Display for CloseableCursor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let data = &*self.data.read().unwrap(); - let s = std::str::from_utf8(data).unwrap_or("not utf8"); + let inner = self.0.read().unwrap(); + let s = std::str::from_utf8(&inner.data).unwrap_or("not utf8"); write!(f, "{}", s) } } @@ -163,13 +175,14 @@ impl TestIO { impl Debug for CloseableCursor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let inner = self.0.read().unwrap(); f.debug_struct("CloseableCursor") .field( "data", - &std::str::from_utf8(&self.data.read().unwrap()).unwrap_or("not utf8"), + &std::str::from_utf8(&inner.data).unwrap_or("not utf8"), ) - .field("closed", &*self.closed.read().unwrap()) - .field("cursor", &*self.cursor.read().unwrap()) + .field("closed", &inner.closed) + .field("cursor", &inner.cursor) .finish() } } @@ -180,18 +193,17 @@ impl Read for &CloseableCursor { cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let len = self.len(); - let cursor = self.cursor(); - if cursor < len { - let data = &*self.data.read().unwrap(); - let bytes_to_copy = buf.len().min(len - cursor); - buf[..bytes_to_copy].copy_from_slice(&data[cursor..cursor + bytes_to_copy]); - *self.cursor.write().unwrap() += bytes_to_copy; + let mut inner = self.0.write().unwrap(); + if inner.cursor < inner.data.len() { + let bytes_to_copy = buf.len().min(inner.data.len() - inner.cursor); + buf[..bytes_to_copy] + .copy_from_slice(&inner.data[inner.cursor..inner.cursor + bytes_to_copy]); + inner.cursor += bytes_to_copy; Poll::Ready(Ok(bytes_to_copy)) - } else if *self.closed.read().unwrap() { + } else if inner.closed { Poll::Ready(Ok(0)) } else { - *self.waker.write().unwrap() = Some(cx.waker().clone()); + inner.waker = Some(cx.waker().clone()); Poll::Pending } } @@ -203,11 +215,12 @@ impl Write for &CloseableCursor { _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - if *self.closed.read().unwrap() { + let mut inner = self.0.write().unwrap(); + if inner.closed { Poll::Ready(Ok(0)) } else { - self.data.write().unwrap().extend_from_slice(buf); - if let Some(waker) = self.waker.write().unwrap().take() { + inner.data.extend_from_slice(buf); + if let Some(waker) = inner.waker.take() { waker.wake(); } Poll::Ready(Ok(buf.len())) @@ -219,10 +232,7 @@ impl Write for &CloseableCursor { } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - if let Some(waker) = self.waker.write().unwrap().take() { - waker.wake(); - } - *self.closed.write().unwrap() = true; + self.close(); Poll::Ready(Ok(())) } }