Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Notify to coordinate waiters #186

Merged
merged 12 commits into from
Jan 29, 2024
4 changes: 2 additions & 2 deletions bb8/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "bb8"
version = "0.8.1"
version = "0.8.2"
edition = "2021"
rust-version = "1.63"
description = "Full-featured async (tokio-based) connection pool (like r2d2)"
Expand All @@ -14,7 +14,7 @@ async-trait = "0.1"
futures-channel = "0.3.2"
futures-util = { version = "0.3.2", default-features = false, features = ["channel"] }
parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.0", features = ["rt", "time"] }
tokio = { version = "1.0", features = ["rt", "sync", "time"] }

[dev-dependencies]
tokio = { version = "1.0", features = ["macros"] }
Expand Down
17 changes: 6 additions & 11 deletions bb8/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ impl<M: ManageConnection> Pool<M> {
/// Using an owning `PooledConnection` makes it easier to leak the connection pool. Therefore, [`Pool::get`]
/// (which stores a lifetime-bound reference to the pool) should be preferred whenever possible.
pub async fn get_owned(&self) -> Result<PooledConnection<'static, M>, RunError<M::Error>> {
self.inner.get_owned().await
Ok(PooledConnection {
conn: self.get().await?.take(),
pool: Cow::Owned(self.inner.clone()),
})
}

/// Get a new dedicated connection that will not be managed by the pool.
Expand Down Expand Up @@ -385,17 +388,9 @@ where
pub(crate) fn drop_invalid(mut self) {
let _ = self.conn.take();
}
}

impl<M> PooledConnection<'static, M>
where
M: ManageConnection,
{
pub(crate) fn new_owned(pool: PoolInner<M>, conn: Conn<M::Connection>) -> Self {
Self {
pool: Cow::Owned(pool),
conn: Some(conn),
}
pub(crate) fn take(&mut self) -> Option<Conn<M::Connection>> {
self.conn.take()
}
}

Expand Down
124 changes: 47 additions & 77 deletions bb8/src/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::future::Future;
use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};

use futures_channel::oneshot;
use futures_util::stream::{FuturesUnordered, StreamExt};
use futures_util::TryFutureExt;
use tokio::spawn;
Expand All @@ -28,12 +27,15 @@ where
let inner = Arc::new(SharedPool::new(builder, manager));

if inner.statics.max_lifetime.is_some() || inner.statics.idle_timeout.is_some() {
let s = Arc::downgrade(&inner);
if let Some(shared) = s.upgrade() {
let start = Instant::now() + shared.statics.reaper_rate;
let interval = interval_at(start.into(), shared.statics.reaper_rate);
schedule_reaping(interval, s);
}
let start = Instant::now() + inner.statics.reaper_rate;
let interval = interval_at(start.into(), inner.statics.reaper_rate);
tokio::spawn(
Reaper {
interval,
pool: Arc::downgrade(&inner),
}
.run(),
);
}

Self { inner }
Expand Down Expand Up @@ -83,65 +85,35 @@ where
}

pub(crate) async fn get(&self) -> Result<PooledConnection<'_, M>, RunError<M::Error>> {
self.make_pooled(|this, conn| PooledConnection::new(this, conn))
.await
}

pub(crate) async fn get_owned(
&self,
) -> Result<PooledConnection<'static, M>, RunError<M::Error>> {
self.make_pooled(|this, conn| {
let pool = PoolInner {
inner: Arc::clone(&this.inner),
};
PooledConnection::new_owned(pool, conn)
})
.await
}

pub(crate) async fn make_pooled<'a, 'b, F>(
&'a self,
make_pooled_conn: F,
) -> Result<PooledConnection<'b, M>, RunError<M::Error>>
where
F: Fn(&'a Self, Conn<M::Connection>) -> PooledConnection<'b, M>,
{
loop {
let mut conn = {
let mut locked = self.inner.internals.lock();
match locked.pop(&self.inner.statics) {
Some((conn, approvals)) => {
self.spawn_replenishing_approvals(approvals);
make_pooled_conn(self, conn)
let future = async {
loop {
let (conn, approvals) = self.inner.pop();
self.spawn_replenishing_approvals(approvals);
let mut conn = match conn {
Some(conn) => PooledConnection::new(self, conn),
None => {
self.inner.notify.notified().await;
continue;
}
None => break,
}
};
};

if !self.inner.statics.test_on_check_out {
return Ok(conn);
}
if !self.inner.statics.test_on_check_out {
return Ok(conn);
}

match self.inner.manager.is_valid(&mut conn).await {
Ok(()) => return Ok(conn),
Err(e) => {
self.inner.forward_error(e);
conn.drop_invalid();
continue;
match self.inner.manager.is_valid(&mut conn).await {
Ok(()) => return Ok(conn),
Err(e) => {
self.inner.forward_error(e);
conn.drop_invalid();
continue;
}
}
}
}

let (tx, rx) = oneshot::channel();
{
let mut locked = self.inner.internals.lock();
let approvals = locked.push_waiter(tx, &self.inner.statics);
self.spawn_replenishing_approvals(approvals);
};

match timeout(self.inner.statics.connection_timeout, rx).await {
Ok(Ok(Ok(mut guard))) => Ok(make_pooled_conn(self, guard.extract())),
Ok(Ok(Err(e))) => Err(RunError::User(e)),
match timeout(self.inner.statics.connection_timeout, future).await {
Ok(result) => result,
_ => Err(RunError::TimedOut),
}
}
Expand Down Expand Up @@ -177,12 +149,6 @@ where
self.inner.internals.lock().state()
}

fn reap(&self) {
let mut internals = self.inner.internals.lock();
let approvals = internals.reap(&self.inner.statics);
self.spawn_replenishing_approvals(approvals);
}

// Outside of Pool to avoid borrow splitting issues on self
async fn add_connection(&self, approval: Approval) -> Result<(), M::Error>
where
Expand Down Expand Up @@ -257,18 +223,22 @@ where
}
}

fn schedule_reaping<M>(mut interval: Interval, weak_shared: Weak<SharedPool<M>>)
where
M: ManageConnection,
{
spawn(async move {
struct Reaper<M: ManageConnection> {
interval: Interval,
pool: Weak<SharedPool<M>>,
}

impl<M: ManageConnection> Reaper<M> {
async fn run(mut self) {
loop {
let _ = interval.tick().await;
if let Some(inner) = weak_shared.upgrade() {
PoolInner { inner }.reap();
} else {
break;
}
let _ = self.interval.tick().await;
let pool = match self.pool.upgrade() {
Some(inner) => PoolInner { inner },
None => break,
};

let approvals = pool.inner.reap();
pool.spawn_replenishing_approvals(approvals);
}
});
}
}
95 changes: 23 additions & 72 deletions bb8/src/internals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use std::time::Instant;

use crate::{api::QueueStrategy, lock::Mutex};
use futures_channel::oneshot;
use tokio::sync::Notify;

use crate::api::{Builder, ManageConnection};
use std::collections::VecDeque;
Expand All @@ -17,6 +17,7 @@ where
pub(crate) statics: Builder<M>,
pub(crate) manager: M,
pub(crate) internals: Mutex<PoolInternals<M>>,
pub(crate) notify: Arc<Notify>,
}

impl<M> SharedPool<M>
Expand All @@ -28,18 +29,27 @@ where
statics,
manager,
internals: Mutex::new(PoolInternals::default()),
notify: Arc::new(Notify::new()),
}
}

pub(crate) fn forward_error(&self, mut err: M::Error) {
pub(crate) fn pop(&self) -> (Option<Conn<M::Connection>>, ApprovalIter) {
let mut locked = self.internals.lock();
while let Some(waiter) = locked.waiters.pop_front() {
match waiter.send(Err(err)) {
Ok(_) => return,
Err(Err(e)) => err = e,
Err(Ok(_)) => unreachable!(),
}
}
let conn = locked.conns.pop_front().map(|idle| idle.conn);
let approvals = match &conn {
Some(_) => locked.wanted(&self.statics),
None => locked.approvals(&self.statics, 1),
};

(conn, approvals)
}

pub(crate) fn reap(&self) -> ApprovalIter {
let mut locked = self.internals.lock();
locked.reap(&self.statics)
}

pub(crate) fn forward_error(&self, err: M::Error) {
self.statics.error_sink.sink(err);
}
}
Expand All @@ -50,7 +60,6 @@ pub(crate) struct PoolInternals<M>
where
M: ManageConnection,
{
waiters: VecDeque<oneshot::Sender<Result<InternalsGuard<M>, M::Error>>>,
conns: VecDeque<IdleConn<M::Connection>>,
num_conns: u32,
pending_conns: u32,
Expand All @@ -60,15 +69,6 @@ impl<M> PoolInternals<M>
where
M: ManageConnection,
{
pub(crate) fn pop(
&mut self,
config: &Builder<M>,
) -> Option<(Conn<M::Connection>, ApprovalIter)> {
self.conns
.pop_front()
.map(|idle| (idle.conn, self.wanted(config)))
}

pub(crate) fn put(
&mut self,
conn: Conn<M::Connection>,
Expand All @@ -80,26 +80,14 @@ where
self.num_conns += 1;
}

let queue_strategy = pool.statics.queue_strategy;

let mut guard = InternalsGuard::new(conn, pool);
while let Some(waiter) = self.waiters.pop_front() {
// This connection is no longer idle, send it back out
match waiter.send(Ok(guard)) {
Ok(()) => return,
Err(Ok(g)) => {
guard = g;
}
Err(Err(_)) => unreachable!(),
}
}

// Queue it in the idle queue
let conn = IdleConn::from(guard.conn.take().unwrap());
match queue_strategy {
let conn = IdleConn::from(conn);
match pool.statics.queue_strategy {
QueueStrategy::Fifo => self.conns.push_back(conn),
QueueStrategy::Lifo => self.conns.push_front(conn),
}

pool.notify.notify_one();
}

pub(crate) fn connect_failed(&mut self, _: Approval) {
Expand All @@ -123,15 +111,6 @@ where
self.approvals(config, wanted)
}

pub(crate) fn push_waiter(
&mut self,
waiter: oneshot::Sender<Result<InternalsGuard<M>, M::Error>>,
config: &Builder<M>,
) -> ApprovalIter {
self.waiters.push_back(waiter);
self.approvals(config, 1)
}

fn approvals(&mut self, config: &Builder<M>, num: u32) -> ApprovalIter {
let current = self.num_conns + self.pending_conns;
let allowed = if current < config.max_size {
Expand Down Expand Up @@ -177,41 +156,13 @@ where
{
fn default() -> Self {
Self {
waiters: VecDeque::new(),
conns: VecDeque::new(),
num_conns: 0,
pending_conns: 0,
}
}
}

pub(crate) struct InternalsGuard<M: ManageConnection> {
conn: Option<Conn<M::Connection>>,
pool: Arc<SharedPool<M>>,
}

impl<M: ManageConnection> InternalsGuard<M> {
fn new(conn: Conn<M::Connection>, pool: Arc<SharedPool<M>>) -> Self {
Self {
conn: Some(conn),
pool,
}
}

pub(crate) fn extract(&mut self) -> Conn<M::Connection> {
self.conn.take().unwrap() // safe: can only be `None` after `Drop`
}
}

impl<M: ManageConnection> Drop for InternalsGuard<M> {
fn drop(&mut self) {
if let Some(conn) = self.conn.take() {
let mut locked = self.pool.internals.lock();
locked.put(conn, None, self.pool.clone());
}
}
}

#[must_use]
pub(crate) struct ApprovalIter {
num: usize,
Expand Down
Loading
Loading