diff --git a/codex-rs/core/src/session/handlers.rs b/codex-rs/core/src/session/handlers.rs index a3de362d61b4..d4d7bea3b470 100644 --- a/codex-rs/core/src/session/handlers.rs +++ b/codex-rs/core/src/session/handlers.rs @@ -617,7 +617,7 @@ pub async fn set_thread_memory_mode(sess: &Arc, sub_id: String, mode: T } } -pub async fn shutdown(sess: &Arc, sub_id: String) -> bool { +async fn shutdown_session_runtime(sess: &Arc) { sess.abort_all_tasks(TurnAbortReason::Interrupted).await; let _ = sess.conversation.shutdown().await; sess.services @@ -630,6 +630,20 @@ pub async fn shutdown(sess: &Arc, sub_id: String) -> bool { }; mcp_shutdown.await; sess.guardian_review_session.shutdown().await; +} + +fn emit_thread_stop_lifecycle(sess: &Session) { + for contributor in sess.services.extensions.thread_lifecycle_contributors() { + contributor.on_thread_stop(codex_extension_api::ThreadStopInput { + thread_id: sess.conversation_id, + session_store: &sess.services.session_extension_data, + thread_store: &sess.services.thread_extension_data, + }); + } +} + +pub async fn shutdown(sess: &Arc, sub_id: String) -> bool { + shutdown_session_runtime(sess).await; info!("Shutting down Codex instance"); let history = sess.clone_history().await; let turn_count = history @@ -643,13 +657,7 @@ pub async fn shutdown(sess: &Arc, sub_id: String) -> bool { &[], ); - for contributor in sess.services.extensions.thread_lifecycle_contributors() { - contributor.on_thread_stop(codex_extension_api::ThreadStopInput { - thread_id: sess.conversation_id, - session_store: &sess.services.session_extension_data, - thread_store: &sess.services.thread_extension_data, - }); - } + emit_thread_stop_lifecycle(sess.as_ref()); // Gracefully flush and shutdown thread persistence on session end so tests // that inspect durable state do not race with the background writer. @@ -722,6 +730,7 @@ pub(super) async fn submission_loop( rx_sub: Receiver, ) { // To break out of this loop, send Op::Shutdown. + let mut shutdown_received = false; while let Ok(sub) = rx_sub.recv().await { debug!(?sub, "Submission"); let dispatch_span = submission_dispatch_span(&sub); @@ -894,23 +903,16 @@ pub(super) async fn submission_loop( .instrument(dispatch_span) .await; if should_exit { + shutdown_received = true; break; } } // If the submission loop exits because the channel closed without an - // explicit shutdown op, still run process teardown for child processes - // owned by this session. - sess.services - .unified_exec_manager - .terminate_all_processes() - .await; - let mcp_shutdown = { - let mut manager = sess.services.mcp_connection_manager.write().await; - manager.begin_shutdown() - }; - mcp_shutdown.await; - // Also drain cached guardian state on this implicit shutdown path. - sess.guardian_review_session.shutdown().await; + // explicit shutdown op, still run session teardown. + if !shutdown_received { + shutdown_session_runtime(&sess).await; + emit_thread_stop_lifecycle(sess.as_ref()); + } debug!("Agent loop exited"); } diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index c5602ddaaea9..5fcb652f8cd1 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -5196,6 +5196,116 @@ async fn shutdown_complete_does_not_append_to_thread_store_after_shutdown() { ); } +#[tokio::test] +async fn submission_loop_channel_close_emits_thread_stop_lifecycle() { + struct SessionStopMarker; + struct ThreadStopMarker; + + struct ThreadStopRecorder { + calls: Arc, + expected_thread_id: ThreadId, + } + + impl codex_extension_api::ThreadLifecycleContributor for ThreadStopRecorder { + fn on_thread_stop(&self, input: codex_extension_api::ThreadStopInput<'_>) { + assert_eq!(self.expected_thread_id, input.thread_id); + assert!(input.session_store.get::().is_some()); + assert!(input.thread_store.get::().is_some()); + self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + } + + let (mut session, turn_context) = make_session_and_context().await; + let calls = Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let mut builder = codex_extension_api::ExtensionRegistryBuilder::::new(); + builder.thread_lifecycle_contributor(Arc::new(ThreadStopRecorder { + calls: Arc::clone(&calls), + expected_thread_id: session.conversation_id, + })); + session.services.extensions = Arc::new(builder.build()); + session + .services + .session_extension_data + .insert(SessionStopMarker); + session + .services + .thread_extension_data + .insert(ThreadStopMarker); + + let (tx_sub, rx_sub) = async_channel::bounded(1); + drop(tx_sub); + let session = Arc::new(session); + submission_loop(session, Arc::clone(&turn_context.config), rx_sub).await; + + assert_eq!(1, calls.load(std::sync::atomic::Ordering::SeqCst)); +} + +#[tokio::test] +async fn submission_loop_channel_close_aborts_active_turn_before_thread_stop_lifecycle() { + struct LifecycleRecorder { + calls: Arc>>, + expected_thread_id: ThreadId, + expected_turn_id: String, + } + + impl codex_extension_api::ThreadLifecycleContributor for LifecycleRecorder { + fn on_thread_stop(&self, input: codex_extension_api::ThreadStopInput<'_>) { + assert_eq!(self.expected_thread_id, input.thread_id); + self.calls + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push("thread_stop"); + } + } + + impl codex_extension_api::TurnLifecycleContributor for LifecycleRecorder { + fn on_turn_abort(&self, input: codex_extension_api::TurnAbortInput<'_>) { + assert_eq!(self.expected_thread_id, input.thread_id); + assert_eq!(self.expected_turn_id, input.turn_id); + assert_eq!(TurnAbortReason::Interrupted, input.reason); + self.calls + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push("turn_abort"); + } + } + + let (mut session, turn_context) = make_session_and_context().await; + let calls = Arc::new(std::sync::Mutex::new(Vec::new())); + let recorder = Arc::new(LifecycleRecorder { + calls: Arc::clone(&calls), + expected_thread_id: session.conversation_id, + expected_turn_id: turn_context.sub_id.clone(), + }); + let mut builder = codex_extension_api::ExtensionRegistryBuilder::::new(); + builder.thread_lifecycle_contributor(recorder.clone()); + builder.turn_lifecycle_contributor(recorder); + session.services.extensions = Arc::new(builder.build()); + + let session = Arc::new(session); + session + .spawn_task( + Arc::new(turn_context), + Vec::new(), + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }, + ) + .await; + + let (tx_sub, rx_sub) = async_channel::bounded(1); + drop(tx_sub); + submission_loop(Arc::clone(&session), session.get_config().await, rx_sub).await; + + assert_eq!( + vec!["turn_abort", "thread_stop"], + *calls + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + ); +} + #[tokio::test] async fn shutdown_and_wait_allows_multiple_waiters() { let (session, _turn_context) = make_session_and_context().await;