Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 44 additions & 34 deletions tests/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,35 +102,47 @@ pub struct TestIO {
}

#[derive(Default)]
pub struct CloseableCursor {
data: RwLock<Vec<u8>>,
cursor: RwLock<usize>,
waker: RwLock<Option<Waker>>,
closed: RwLock<bool>,
struct CloseableCursorInner {
data: Vec<u8>,
cursor: usize,
waker: Option<Waker>,
closed: bool,
}

#[derive(Default)]
pub struct CloseableCursor(RwLock<CloseableCursorInner>);

impl CloseableCursor {
fn len(&self) -> usize {
self.data.read().unwrap().len()
pub fn len(&self) -> usize {
self.0.read().unwrap().data.len()
}

pub fn cursor(&self) -> usize {
self.0.read().unwrap().cursor
}

fn cursor(&self) -> usize {
*self.cursor.read().unwrap()
pub fn is_empty(&self) -> bool {
self.len() == 0
}

fn current(&self) -> bool {
self.len() == self.cursor()
pub fn current(&self) -> bool {
let inner = self.0.read().unwrap();
inner.data.len() == inner.cursor
}

fn close(&self) {
*self.closed.write().unwrap() = true;
pub fn close(&self) {
let mut inner = self.0.write().unwrap();
inner.closed = true;
if let Some(waker) = inner.waker.take() {
waker.wake();
}
}
}

impl Display for CloseableCursor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let data = &*self.data.read().unwrap();
let s = std::str::from_utf8(data).unwrap_or("not utf8");
let inner = self.0.read().unwrap();
let s = std::str::from_utf8(&inner.data).unwrap_or("not utf8");
write!(f, "{}", s)
}
}
Expand Down Expand Up @@ -163,13 +175,14 @@ impl TestIO {

impl Debug for CloseableCursor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let inner = self.0.read().unwrap();
f.debug_struct("CloseableCursor")
.field(
"data",
&std::str::from_utf8(&self.data.read().unwrap()).unwrap_or("not utf8"),
&std::str::from_utf8(&inner.data).unwrap_or("not utf8"),
)
.field("closed", &*self.closed.read().unwrap())
.field("cursor", &*self.cursor.read().unwrap())
.field("closed", &inner.closed)
.field("cursor", &inner.cursor)
.finish()
}
}
Expand All @@ -180,18 +193,17 @@ impl Read for &CloseableCursor {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let len = self.len();
let cursor = self.cursor();
if cursor < len {
let data = &*self.data.read().unwrap();
let bytes_to_copy = buf.len().min(len - cursor);
buf[..bytes_to_copy].copy_from_slice(&data[cursor..cursor + bytes_to_copy]);
*self.cursor.write().unwrap() += bytes_to_copy;
let mut inner = self.0.write().unwrap();
if inner.cursor < inner.data.len() {
let bytes_to_copy = buf.len().min(inner.data.len() - inner.cursor);
buf[..bytes_to_copy]
.copy_from_slice(&inner.data[inner.cursor..inner.cursor + bytes_to_copy]);
inner.cursor += bytes_to_copy;
Poll::Ready(Ok(bytes_to_copy))
} else if *self.closed.read().unwrap() {
} else if inner.closed {
Poll::Ready(Ok(0))
} else {
*self.waker.write().unwrap() = Some(cx.waker().clone());
inner.waker = Some(cx.waker().clone());
Poll::Pending
}
}
Expand All @@ -203,11 +215,12 @@ impl Write for &CloseableCursor {
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if *self.closed.read().unwrap() {
let mut inner = self.0.write().unwrap();
if inner.closed {
Poll::Ready(Ok(0))
} else {
self.data.write().unwrap().extend_from_slice(buf);
if let Some(waker) = self.waker.write().unwrap().take() {
inner.data.extend_from_slice(buf);
if let Some(waker) = inner.waker.take() {
waker.wake();
}
Poll::Ready(Ok(buf.len()))
Expand All @@ -219,10 +232,7 @@ impl Write for &CloseableCursor {
}

fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if let Some(waker) = self.waker.write().unwrap().take() {
waker.wake();
}
*self.closed.write().unwrap() = true;
self.close();
Poll::Ready(Ok(()))
}
}
Expand Down