Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "openab"
version = "0.7.4"
version = "0.7.5"
edition = "2021"

[dependencies]
Expand Down
289 changes: 227 additions & 62 deletions src/acp/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@ use crate::acp::connection::AcpConnection;
use crate::config::AgentConfig;
use anyhow::{anyhow, Result};
use std::collections::HashMap;
use tokio::sync::RwLock;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tokio::time::Instant;
use tracing::{info, warn};

/// Combined state protected by a single lock to prevent deadlocks.
/// Lock ordering: always acquire `state` before any operation on either map.
/// Lock ordering: never await a per-connection mutex while holding `state`.
struct PoolState {
/// Active connections: thread_key → AcpConnection.
active: HashMap<String, AcpConnection>,
/// Active connections: thread_key → AcpConnection handle.
active: HashMap<String, Arc<Mutex<AcpConnection>>>,
/// Suspended sessions: thread_key → ACP sessionId.
/// Saved on eviction so sessions can be resumed via `session/load`.
suspended: HashMap<String, String>,
/// Serializes create/resume work per thread so rapid same-thread requests
/// cannot race each other into duplicate `session/load` attempts.
creating: HashMap<String, Arc<Mutex<()>>>,
}

pub struct SessionPool {
Expand All @@ -22,71 +26,121 @@ pub struct SessionPool {
max_sessions: usize,
}

type EvictionCandidate = (
String,
Arc<Mutex<AcpConnection>>,
Instant,
Option<String>,
);

fn remove_if_same_handle<T>(
map: &mut HashMap<String, Arc<Mutex<T>>>,
key: &str,
expected: &Arc<Mutex<T>>,
) -> Option<Arc<Mutex<T>>> {
let should_remove = map
.get(key)
.is_some_and(|current| Arc::ptr_eq(current, expected));
if should_remove {
map.remove(key)
} else {
None
}
}

fn get_or_insert_gate(
map: &mut HashMap<String, Arc<Mutex<()>>>,
key: &str,
) -> Arc<Mutex<()>> {
map.entry(key.to_string())
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone()
}

impl SessionPool {
pub fn new(config: AgentConfig, max_sessions: usize) -> Self {
Self {
state: RwLock::new(PoolState {
active: HashMap::new(),
suspended: HashMap::new(),
creating: HashMap::new(),
}),
config,
max_sessions,
}
}

pub async fn get_or_create(&self, thread_id: &str) -> Result<()> {
// Check if alive connection exists
{
let state = self.state.read().await;
if let Some(conn) = state.active.get(thread_id) {
if conn.alive() {
return Ok(());
}
}
}
let create_gate = {
let mut state = self.state.write().await;
get_or_insert_gate(&mut state.creating, thread_id)
};
let _create_guard = create_gate.lock().await;

// Need to create or rebuild
let mut state = self.state.write().await;
let (existing, saved_session_id) = {
let state = self.state.read().await;
(
state.active.get(thread_id).cloned(),
state.suspended.get(thread_id).cloned(),
)
};

// Double-check after acquiring write lock
if let Some(conn) = state.active.get(thread_id) {
let had_existing = existing.is_some();
let mut saved_session_id = saved_session_id;
if let Some(conn) = existing.clone() {
let conn = conn.lock().await;
if conn.alive() {
return Ok(());
}
warn!(thread_id, "stale connection, rebuilding");
suspend_entry(&mut state, thread_id);
if saved_session_id.is_none() {
saved_session_id = conn.acp_session_id.clone();
}
}

if state.active.len() >= self.max_sessions {
// LRU evict: suspend the oldest idle session to make room
let oldest = state.active
// Snapshot active handles so we can inspect them outside the state lock.
let snapshot: Vec<(String, Arc<Mutex<AcpConnection>>)> = {
let state = self.state.read().await;
state
.active
.iter()
.min_by_key(|(_, c)| c.last_active)
.map(|(k, _)| k.clone());
if let Some(key) = oldest {
info!(evicted = %key, "pool full, suspending oldest idle session");
suspend_entry(&mut state, &key);
} else {
return Err(anyhow!("pool exhausted ({} sessions)", self.max_sessions));
.map(|(k, v)| (k.clone(), Arc::clone(v)))
.collect()
};

let mut eviction_candidate: Option<EvictionCandidate> = None;
let mut skipped_locked_candidates = 0usize;
for (key, conn) in snapshot {
if key == thread_id {
continue;
}
let conn_handle = Arc::clone(&conn);
let Ok(conn) = conn.try_lock() else {
skipped_locked_candidates += 1;
continue;
};
let candidate = (key, conn_handle, conn.last_active, conn.acp_session_id.clone());
match &eviction_candidate {
Some((_, _, oldest_last_active, _)) if candidate.2 >= *oldest_last_active => {}
_ => eviction_candidate = Some(candidate),
}
}

let mut conn = AcpConnection::spawn(
// Build the replacement connection outside the state lock so one stuck
// initialization does not block all unrelated sessions.
let mut new_conn = AcpConnection::spawn(
&self.config.command,
&self.config.args,
&self.config.working_dir,
&self.config.env,
)
.await?;

conn.initialize().await?;
new_conn.initialize().await?;

// Try to resume a suspended session via session/load
let saved_session_id = state.suspended.remove(thread_id);
let mut resumed = false;
if let Some(ref sid) = saved_session_id {
if conn.supports_load_session {
match conn.session_load(sid, &self.config.working_dir).await {
if new_conn.supports_load_session {
match new_conn.session_load(sid, &self.config.working_dir).await {
Ok(()) => {
info!(thread_id, session_id = %sid, "session resumed via session/load");
resumed = true;
Expand All @@ -99,39 +153,119 @@ impl SessionPool {
}

if !resumed {
conn.session_new(&self.config.working_dir).await?;
if saved_session_id.is_some() {
conn.session_reset = true;
new_conn.session_new(&self.config.working_dir).await?;
// Surface the reset banner both for restored sessions and for stale
// live entries that died before we could recover a resumable
// session id. In both cases the caller is continuing after an
// unexpected session loss.
if had_existing || saved_session_id.is_some() {
new_conn.session_reset = true;
}
}

let new_conn = Arc::new(Mutex::new(new_conn));

let mut state = self.state.write().await;

// Another task may have created a healthy connection while we were
// initializing this one.
if let Some(existing) = state.active.get(thread_id).cloned() {
let Ok(existing) = existing.try_lock() else {
return Ok(());
};
if existing.alive() {
return Ok(());
}
warn!(thread_id, "stale connection, rebuilding");
drop(existing);
state.active.remove(thread_id);
}

if state.active.len() >= self.max_sessions {
if let Some((key, expected_conn, _, sid)) = eviction_candidate {
if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() {
info!(evicted = %key, "pool full, suspending oldest idle session");
if let Some(sid) = sid {
state.suspended.insert(key, sid);
}
} else {
warn!(evicted = %key, "pool full but eviction candidate changed before removal");
}
} else if skipped_locked_candidates > 0 {
warn!(
max_sessions = self.max_sessions,
skipped_locked_candidates,
"pool full but all other sessions were busy during eviction scan"
);
}
}

state.active.insert(thread_id.to_string(), conn);
if state.active.len() >= self.max_sessions {
return Err(anyhow!("pool exhausted ({} sessions)", self.max_sessions));
}

state.suspended.remove(thread_id);
state.active.insert(thread_id.to_string(), new_conn);
Ok(())
}

/// Get mutable access to a connection. Caller must have called get_or_create first.
pub async fn with_connection<F, R>(&self, thread_id: &str, f: F) -> Result<R>
where
F: FnOnce(&mut AcpConnection) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<R>> + Send + '_>>,
F: for<'a> FnOnce(
&'a mut AcpConnection,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<R>> + Send + 'a>>,
{
let mut state = self.state.write().await;
let conn = state.active
.get_mut(thread_id)
.ok_or_else(|| anyhow!("no connection for thread {thread_id}"))?;
f(conn).await
let conn = {
let state = self.state.read().await;
state
.active
.get(thread_id)
.cloned()
.ok_or_else(|| anyhow!("no connection for thread {thread_id}"))?
};

let mut conn = conn.lock().await;
f(&mut conn).await
}

pub async fn cleanup_idle(&self, ttl_secs: u64) {
let cutoff = Instant::now() - std::time::Duration::from_secs(ttl_secs);

let snapshot: Vec<(String, Arc<Mutex<AcpConnection>>)> = {
let state = self.state.read().await;
state
.active
.iter()
.map(|(k, v)| (k.clone(), Arc::clone(v)))
.collect()
};

let mut stale = Vec::new();
for (key, conn) in snapshot {
// Skip active sessions for this cleanup round instead of waiting on
// their per-connection mutex. A busy session is not idle.
let conn_handle = Arc::clone(&conn);
let Ok(conn) = conn.try_lock() else {
continue;
};
if conn.last_active < cutoff || !conn.alive() {
stale.push((key, conn_handle, conn.acp_session_id.clone()));
}
}

if stale.is_empty() {
return;
}

let mut state = self.state.write().await;
let stale: Vec<String> = state.active
.iter()
.filter(|(_, c)| c.last_active < cutoff || !c.alive())
.map(|(k, _)| k.clone())
.collect();
for key in stale {
info!(thread_id = %key, "cleaning up idle session");
suspend_entry(&mut state, &key);
for (key, expected_conn, sid) in stale {
if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() {
info!(thread_id = %key, "cleaning up idle session");
if let Some(sid) = sid {
state.suspended.insert(key, sid);
}
}
}
}

Expand All @@ -143,14 +277,45 @@ impl SessionPool {
}
}

/// Suspend a connection: save its sessionId to the suspended map and remove
/// from active. The connection is dropped, triggering process group kill.
fn suspend_entry(state: &mut PoolState, thread_id: &str) {
if let Some(conn) = state.active.remove(thread_id) {
if let Some(sid) = &conn.acp_session_id {
info!(thread_id, session_id = %sid, "suspending session");
state.suspended.insert(thread_id.to_string(), sid.clone());
}
// conn dropped here → Drop impl kills process group
#[cfg(test)]
mod tests {
use super::{get_or_insert_gate, remove_if_same_handle};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;

#[test]
fn remove_if_same_handle_removes_matching_entry() {
let expected = Arc::new(Mutex::new(1_u8));
let mut map = HashMap::from([("thread".to_string(), Arc::clone(&expected))]);

let removed = remove_if_same_handle(&mut map, "thread", &expected);

assert!(removed.is_some());
assert!(map.is_empty());
}

#[test]
fn remove_if_same_handle_keeps_replaced_entry() {
let stale = Arc::new(Mutex::new(1_u8));
let fresh = Arc::new(Mutex::new(2_u8));
let mut map = HashMap::from([("thread".to_string(), Arc::clone(&fresh))]);

let removed = remove_if_same_handle(&mut map, "thread", &stale);

assert!(removed.is_none());
let current = map.get("thread").expect("entry should remain");
assert!(Arc::ptr_eq(current, &fresh));
}

#[test]
fn get_or_insert_gate_reuses_gate_for_same_thread() {
let mut map = HashMap::new();

let first = get_or_insert_gate(&mut map, "thread");
let second = get_or_insert_gate(&mut map, "thread");

assert!(Arc::ptr_eq(&first, &second));
assert_eq!(map.len(), 1);
}
}
Loading