From 24c620e1d2e5674ee4f94b7bbdfe5de9f3c3fbe3 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Tue, 19 May 2026 17:04:09 -0400 Subject: [PATCH] refactor(lib): use a panic_if_poisoned() helper for mutexes --- src/common/lock.rs | 15 +++++++++++++++ src/common/mod.rs | 1 + src/ffi/task.rs | 10 ++++++---- src/mock.rs | 20 +++++++++++--------- src/proto/h2/ping.rs | 9 +++++---- src/upgrade.rs | 3 ++- 6 files changed, 40 insertions(+), 18 deletions(-) create mode 100644 src/common/lock.rs diff --git a/src/common/lock.rs b/src/common/lock.rs new file mode 100644 index 0000000000..4e2b70bff2 --- /dev/null +++ b/src/common/lock.rs @@ -0,0 +1,15 @@ +use std::sync::LockResult; + +pub(crate) trait LockResultExt { + fn panic_if_poisoned(self) -> T; +} + +impl LockResultExt for LockResult { + #[track_caller] + fn panic_if_poisoned(self) -> T { + match self { + Ok(inner) => inner, + Err(err) => panic!("lock poisoned by panic: {err}"), + } + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs index 4b73437203..5be740b000 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -10,6 +10,7 @@ pub(crate) mod either; ))] pub(crate) mod future; pub(crate) mod io; +pub(crate) mod lock; #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] pub(crate) mod task; #[cfg(any( diff --git a/src/ffi/task.rs b/src/ffi/task.rs index 0c49936c6e..47c74c5dbf 100644 --- a/src/ffi/task.rs +++ b/src/ffi/task.rs @@ -10,6 +10,8 @@ use std::task::{Context, Poll}; use futures_util::stream::{FuturesUnordered, Stream}; +use crate::common::lock::LockResultExt; + use super::error::hyper_code; use super::UserDataPointer; @@ -196,7 +198,7 @@ impl hyper_executor { fn spawn(&self, task: Box) { self.spawn_queue .lock() - .unwrap() + .panic_if_poisoned() .push(TaskFuture { task: Some(task) }); } @@ -211,7 +213,7 @@ impl hyper_executor { { // Scope the lock on the driver to ensure it is dropped before // calling drain_queue below. - let mut driver = self.driver.lock().unwrap(); + let mut driver = self.driver.lock().panic_if_poisoned(); match Pin::new(&mut *driver).poll_next(&mut cx) { Poll::Ready(val) => return val, Poll::Pending => {} @@ -238,12 +240,12 @@ impl hyper_executor { /// drain_queue locks both self.spawn_queue and self.driver, so it requires /// that neither of them be locked already. fn drain_queue(&self) -> bool { - let mut queue = self.spawn_queue.lock().unwrap(); + let mut queue = self.spawn_queue.lock().panic_if_poisoned(); if queue.is_empty() { return false; } - let driver = self.driver.lock().unwrap(); + let driver = self.driver.lock().panic_if_poisoned(); for task in queue.drain(..) { driver.push(task); diff --git a/src/mock.rs b/src/mock.rs index 1dd57de319..322ea94f1d 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -16,6 +16,8 @@ use tokio_io::{AsyncRead, AsyncWrite}; #[cfg(feature = "runtime")] use crate::client::connect::{Connect, Connected, Destination}; +#[cfg(feature = "runtime")] +use crate::common::lock::LockResultExt; @@ -59,14 +61,14 @@ impl Duplex { #[cfg(feature = "runtime")] impl Read for Duplex { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.lock().unwrap().read.read(buf) + self.inner.lock().panic_if_poisoned().read.read(buf) } } #[cfg(feature = "runtime")] impl Write for Duplex { fn write(&mut self, buf: &[u8]) -> io::Result { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock().panic_if_poisoned(); let ret = inner.write.write(buf); if let Some(task) = inner.handle_read_task.take() { trace!("waking DuplexHandle read"); @@ -76,7 +78,7 @@ impl Write for Duplex { } fn flush(&mut self) -> io::Result<()> { - self.inner.lock().unwrap().write.flush() + self.inner.lock().panic_if_poisoned().write.flush() } } @@ -91,7 +93,7 @@ impl AsyncWrite for Duplex { } fn write_buf(&mut self, buf: &mut B) -> Poll { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock().panic_if_poisoned(); if let Some(task) = inner.handle_read_task.take() { task.notify(); } @@ -107,7 +109,7 @@ pub struct DuplexHandle { #[cfg(feature = "runtime")] impl DuplexHandle { pub fn read(&self, buf: &mut [u8]) -> Poll { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock().panic_if_poisoned(); assert!(buf.len() >= inner.write.inner.len()); if inner.write.inner.is_empty() { trace!("DuplexHandle read parking"); @@ -118,7 +120,7 @@ impl DuplexHandle { } pub fn write(&self, bytes: &[u8]) -> Poll { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock().panic_if_poisoned(); assert_eq!(inner.read.inner.pos, 0); assert_eq!(inner.read.inner.vec.len(), 0, "write but read isn't empty"); inner @@ -136,7 +138,7 @@ impl Drop for DuplexHandle { fn drop(&mut self) { trace!("mock duplex handle drop"); if !::std::thread::panicking() { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock().panic_if_poisoned(); inner.read.close(); inner.write.close(); } @@ -187,7 +189,7 @@ impl MockConnector { trace!("MockConnector mocked fut ready"); Ok((duplex, connected)) })); - self.mocks.lock().unwrap().0.entry(key) + self.mocks.lock().panic_if_poisoned().0.entry(key) .or_insert(Vec::new()) .push(fut); @@ -208,7 +210,7 @@ impl Connect for MockConnector { } else { "".to_owned() }); - let mut mocks = self.mocks.lock().unwrap(); + let mut mocks = self.mocks.lock().panic_if_poisoned(); let mocks = mocks.0.get_mut(&key) .expect(&format!("unknown mocks uri: {}", key)); assert!(!mocks.is_empty(), "no additional mocks for {}", key); diff --git a/src/proto/h2/ping.rs b/src/proto/h2/ping.rs index 4952e38518..741c066a3f 100644 --- a/src/proto/h2/ping.rs +++ b/src/proto/h2/ping.rs @@ -28,6 +28,7 @@ use std::time::{Duration, Instant}; use h2::{Ping, PingPong}; +use crate::common::lock::LockResultExt; use crate::common::time::Time; use crate::rt::Sleep; @@ -196,7 +197,7 @@ impl Recorder { return; }; - let mut locked = shared.lock().unwrap(); + let mut locked = shared.lock().panic_if_poisoned(); locked.update_last_read_at(); @@ -230,7 +231,7 @@ impl Recorder { return; }; - let mut locked = shared.lock().unwrap(); + let mut locked = shared.lock().panic_if_poisoned(); locked.update_last_read_at(); } @@ -248,7 +249,7 @@ impl Recorder { pub(super) fn ensure_not_timed_out(&self) -> crate::Result<()> { if let Some(ref shared) = self.shared { - let locked = shared.lock().unwrap(); + let locked = shared.lock().panic_if_poisoned(); if locked.is_keep_alive_timed_out { return Err(KeepAliveTimedOut.crate_error()); } @@ -263,7 +264,7 @@ impl Recorder { impl Ponger { pub(super) fn poll(&mut self, cx: &mut task::Context<'_>) -> Poll { - let mut locked = self.shared.lock().unwrap(); + let mut locked = self.shared.lock().panic_if_poisoned(); let now = locked.timer.now(); // hoping this is fine to move within the lock let is_idle = self.is_idle(); diff --git a/src/upgrade.rs b/src/upgrade.rs index 9d23a29081..de144cf071 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -53,6 +53,7 @@ use bytes::Bytes; use tokio::sync::oneshot; use crate::common::io::Rewind; +use crate::common::lock::LockResultExt; /// An upgraded HTTP connection. /// @@ -226,7 +227,7 @@ impl Future for OnUpgrade { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.rx { - Some(ref rx) => Pin::new(&mut *rx.lock().unwrap()) + Some(ref rx) => Pin::new(&mut *rx.lock().panic_if_poisoned()) .poll(cx) .map(|res| match res { Ok(Ok(upgraded)) => Ok(upgraded),