Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions codex-rs/core/src/session/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ pub async fn set_thread_memory_mode(sess: &Arc<Session>, sub_id: String, mode: T
}
}

pub async fn shutdown(sess: &Arc<Session>, sub_id: String) -> bool {
async fn shutdown_session_runtime(sess: &Arc<Session>) {
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let _ = sess.conversation.shutdown().await;
sess.services
Expand All @@ -630,6 +630,20 @@ pub async fn shutdown(sess: &Arc<Session>, sub_id: String) -> bool {
};
mcp_shutdown.await;
sess.guardian_review_session.shutdown().await;
}

fn emit_thread_stop_lifecycle(sess: &Session) {
for contributor in sess.services.extensions.thread_lifecycle_contributors() {
contributor.on_thread_stop(codex_extension_api::ThreadStopInput {
thread_id: sess.conversation_id,
session_store: &sess.services.session_extension_data,
thread_store: &sess.services.thread_extension_data,
});
}
}

pub async fn shutdown(sess: &Arc<Session>, sub_id: String) -> bool {
shutdown_session_runtime(sess).await;
info!("Shutting down Codex instance");
let history = sess.clone_history().await;
let turn_count = history
Expand All @@ -643,13 +657,7 @@ pub async fn shutdown(sess: &Arc<Session>, sub_id: String) -> bool {
&[],
);

for contributor in sess.services.extensions.thread_lifecycle_contributors() {
contributor.on_thread_stop(codex_extension_api::ThreadStopInput {
thread_id: sess.conversation_id,
session_store: &sess.services.session_extension_data,
thread_store: &sess.services.thread_extension_data,
});
}
emit_thread_stop_lifecycle(sess.as_ref());

// Gracefully flush and shutdown thread persistence on session end so tests
// that inspect durable state do not race with the background writer.
Expand Down Expand Up @@ -722,6 +730,7 @@ pub(super) async fn submission_loop(
rx_sub: Receiver<Submission>,
) {
// To break out of this loop, send Op::Shutdown.
let mut shutdown_received = false;
while let Ok(sub) = rx_sub.recv().await {
debug!(?sub, "Submission");
let dispatch_span = submission_dispatch_span(&sub);
Expand Down Expand Up @@ -894,23 +903,16 @@ pub(super) async fn submission_loop(
.instrument(dispatch_span)
.await;
if should_exit {
shutdown_received = true;
break;
}
}
// If the submission loop exits because the channel closed without an
// explicit shutdown op, still run process teardown for child processes
// owned by this session.
sess.services
.unified_exec_manager
.terminate_all_processes()
.await;
let mcp_shutdown = {
let mut manager = sess.services.mcp_connection_manager.write().await;
manager.begin_shutdown()
};
mcp_shutdown.await;
// Also drain cached guardian state on this implicit shutdown path.
sess.guardian_review_session.shutdown().await;
// explicit shutdown op, still run session teardown.
if !shutdown_received {
shutdown_session_runtime(&sess).await;
Comment thread
jif-oai marked this conversation as resolved.
emit_thread_stop_lifecycle(sess.as_ref());
Comment thread
jif-oai marked this conversation as resolved.
}
debug!("Agent loop exited");
}

Expand Down
110 changes: 110 additions & 0 deletions codex-rs/core/src/session/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5196,6 +5196,116 @@ async fn shutdown_complete_does_not_append_to_thread_store_after_shutdown() {
);
}

#[tokio::test]
async fn submission_loop_channel_close_emits_thread_stop_lifecycle() {
struct SessionStopMarker;
struct ThreadStopMarker;

struct ThreadStopRecorder {
calls: Arc<std::sync::atomic::AtomicUsize>,
expected_thread_id: ThreadId,
}

impl codex_extension_api::ThreadLifecycleContributor<crate::config::Config> for ThreadStopRecorder {
fn on_thread_stop(&self, input: codex_extension_api::ThreadStopInput<'_>) {
assert_eq!(self.expected_thread_id, input.thread_id);
assert!(input.session_store.get::<SessionStopMarker>().is_some());
assert!(input.thread_store.get::<ThreadStopMarker>().is_some());
self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}

let (mut session, turn_context) = make_session_and_context().await;
let calls = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let mut builder = codex_extension_api::ExtensionRegistryBuilder::<crate::config::Config>::new();
builder.thread_lifecycle_contributor(Arc::new(ThreadStopRecorder {
calls: Arc::clone(&calls),
expected_thread_id: session.conversation_id,
}));
session.services.extensions = Arc::new(builder.build());
session
.services
.session_extension_data
.insert(SessionStopMarker);
session
.services
.thread_extension_data
.insert(ThreadStopMarker);

let (tx_sub, rx_sub) = async_channel::bounded(1);
drop(tx_sub);
let session = Arc::new(session);
submission_loop(session, Arc::clone(&turn_context.config), rx_sub).await;

assert_eq!(1, calls.load(std::sync::atomic::Ordering::SeqCst));
}

#[tokio::test]
async fn submission_loop_channel_close_aborts_active_turn_before_thread_stop_lifecycle() {
struct LifecycleRecorder {
calls: Arc<std::sync::Mutex<Vec<&'static str>>>,
expected_thread_id: ThreadId,
expected_turn_id: String,
}

impl codex_extension_api::ThreadLifecycleContributor<crate::config::Config> for LifecycleRecorder {
fn on_thread_stop(&self, input: codex_extension_api::ThreadStopInput<'_>) {
assert_eq!(self.expected_thread_id, input.thread_id);
self.calls
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.push("thread_stop");
}
}

impl codex_extension_api::TurnLifecycleContributor for LifecycleRecorder {
fn on_turn_abort(&self, input: codex_extension_api::TurnAbortInput<'_>) {
assert_eq!(self.expected_thread_id, input.thread_id);
assert_eq!(self.expected_turn_id, input.turn_id);
assert_eq!(TurnAbortReason::Interrupted, input.reason);
self.calls
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.push("turn_abort");
}
}

let (mut session, turn_context) = make_session_and_context().await;
let calls = Arc::new(std::sync::Mutex::new(Vec::new()));
let recorder = Arc::new(LifecycleRecorder {
calls: Arc::clone(&calls),
expected_thread_id: session.conversation_id,
expected_turn_id: turn_context.sub_id.clone(),
});
let mut builder = codex_extension_api::ExtensionRegistryBuilder::<crate::config::Config>::new();
builder.thread_lifecycle_contributor(recorder.clone());
builder.turn_lifecycle_contributor(recorder);
session.services.extensions = Arc::new(builder.build());

let session = Arc::new(session);
session
.spawn_task(
Arc::new(turn_context),
Vec::new(),
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: true,
},
)
.await;

let (tx_sub, rx_sub) = async_channel::bounded(1);
drop(tx_sub);
submission_loop(Arc::clone(&session), session.get_config().await, rx_sub).await;

assert_eq!(
vec!["turn_abort", "thread_stop"],
*calls
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
);
}

#[tokio::test]
async fn shutdown_and_wait_allows_multiple_waiters() {
let (session, _turn_context) = make_session_and_context().await;
Expand Down
Loading