From 2cafd783ac22e70083773157387068e6883b75eb Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Thu, 2 Apr 2026 16:15:18 -0700 Subject: [PATCH] core: use native async SessionTask trait --- codex-rs/core/src/codex_tests.rs | 2 - codex-rs/core/src/state/turn.rs | 4 +- codex-rs/core/src/tasks/compact.rs | 6 +- codex-rs/core/src/tasks/ghost_snapshot.rs | 2 - codex-rs/core/src/tasks/mod.rs | 76 +++++++++++++++++-- codex-rs/core/src/tasks/regular.rs | 2 - codex-rs/core/src/tasks/review.rs | 4 +- codex-rs/core/src/tasks/undo.rs | 12 ++- codex-rs/core/src/tasks/user_shell.rs | 2 - .../src/tools/handlers/multi_agents_tests.rs | 1 - 10 files changed, 79 insertions(+), 32 deletions(-) diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index 4ff31e8bd00..c9391207395 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -3119,7 +3119,6 @@ async fn spawn_task_turn_span_inherits_dispatch_trace_context() { captured_trace: Arc>>, } - #[async_trait::async_trait] impl SessionTask for TraceCaptureTask { fn kind(&self) -> TaskKind { TaskKind::Regular @@ -4375,7 +4374,6 @@ struct NeverEndingTask { listen_to_cancellation_token: bool, } -#[async_trait::async_trait] impl SessionTask for NeverEndingTask { fn kind(&self) -> TaskKind { self.kind diff --git a/codex-rs/core/src/state/turn.rs b/codex-rs/core/src/state/turn.rs index a8e3e167b53..214fd8be150 100644 --- a/codex-rs/core/src/state/turn.rs +++ b/codex-rs/core/src/state/turn.rs @@ -18,7 +18,7 @@ use rmcp::model::RequestId; use tokio::sync::oneshot; use crate::codex::TurnContext; -use crate::tasks::SessionTask; +use crate::tasks::AnySessionTask; use codex_protocol::models::PermissionProfile; use codex_protocol::protocol::ReviewDecision; use codex_protocol::protocol::TokenUsage; @@ -69,7 +69,7 @@ pub(crate) enum TaskKind { pub(crate) struct RunningTask { pub(crate) done: Arc, pub(crate) kind: TaskKind, - pub(crate) task: Arc, + pub(crate) task: Arc, pub(crate) cancellation_token: CancellationToken, pub(crate) handle: Arc>, pub(crate) turn_context: Arc, diff --git a/codex-rs/core/src/tasks/compact.rs b/codex-rs/core/src/tasks/compact.rs index a2d94bdc0aa..8c7998d8537 100644 --- a/codex-rs/core/src/tasks/compact.rs +++ b/codex-rs/core/src/tasks/compact.rs @@ -4,14 +4,12 @@ use super::SessionTask; use super::SessionTaskContext; use crate::codex::TurnContext; use crate::state::TaskKind; -use async_trait::async_trait; use codex_protocol::user_input::UserInput; use tokio_util::sync::CancellationToken; #[derive(Clone, Copy, Default)] pub(crate) struct CompactTask; -#[async_trait] impl SessionTask for CompactTask { fn kind(&self) -> TaskKind { TaskKind::Compact @@ -30,14 +28,14 @@ impl SessionTask for CompactTask { ) -> Option { let session = session.clone_session(); let _ = if crate::compact::should_use_remote_compact_task(&ctx.provider) { - let _ = session.services.session_telemetry.counter( + session.services.session_telemetry.counter( "codex.task.compact", /*inc*/ 1, &[("type", "remote")], ); crate::compact_remote::run_remote_compact_task(session.clone(), ctx).await } else { - let _ = session.services.session_telemetry.counter( + session.services.session_telemetry.counter( "codex.task.compact", /*inc*/ 1, &[("type", "local")], diff --git a/codex-rs/core/src/tasks/ghost_snapshot.rs b/codex-rs/core/src/tasks/ghost_snapshot.rs index 7b848098a4f..a9cf19cb4d6 100644 --- a/codex-rs/core/src/tasks/ghost_snapshot.rs +++ b/codex-rs/core/src/tasks/ghost_snapshot.rs @@ -2,7 +2,6 @@ use crate::codex::TurnContext; use crate::state::TaskKind; use crate::tasks::SessionTask; use crate::tasks::SessionTaskContext; -use async_trait::async_trait; use codex_git_utils::CreateGhostCommitOptions; use codex_git_utils::GhostSnapshotReport; use codex_git_utils::GitToolingError; @@ -26,7 +25,6 @@ pub(crate) struct GhostSnapshotTask { const SNAPSHOT_WARNING_THRESHOLD: Duration = Duration::from_secs(240); -#[async_trait] impl SessionTask for GhostSnapshotTask { fn kind(&self) -> TaskKind { TaskKind::Regular diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 74b2e471cd8..c0b5ad4d91d 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use std::time::Duration; use std::time::Instant; -use async_trait::async_trait; +use futures::future::BoxFuture; use tokio::select; use tokio::sync::Notify; use tokio_util::sync::CancellationToken; @@ -126,7 +126,6 @@ impl SessionTaskContext { /// intentionally small: implementers identify themselves via /// [`SessionTask::kind`], perform their work in [`SessionTask::run`], and may /// release resources in [`SessionTask::abort`]. -#[async_trait] pub(crate) trait SessionTask: Send + Sync + 'static { /// Describes the type of work the task performs so the session can /// surface it in telemetry and UI. @@ -143,21 +142,84 @@ pub(crate) trait SessionTask: Send + Sync + 'static { /// abort; implementers should watch for it and terminate quickly once it /// fires. Returning [`Some`] yields a final message that /// [`Session::on_task_finished`] will emit to the client. - async fn run( + fn run( self: Arc, session: Arc, ctx: Arc, input: Vec, cancellation_token: CancellationToken, - ) -> Option; + ) -> impl std::future::Future> + Send; /// Gives the task a chance to perform cleanup after an abort. /// /// The default implementation is a no-op; override this if additional /// teardown or notifications are required once /// [`Session::abort_all_tasks`] cancels the task. - async fn abort(&self, session: Arc, ctx: Arc) { - let _ = (session, ctx); + fn abort( + &self, + session: Arc, + ctx: Arc, + ) -> impl std::future::Future + Send { + async move { + let _ = (session, ctx); + } + } +} + +pub(crate) trait AnySessionTask: Send + Sync + 'static { + fn kind(&self) -> TaskKind; + + fn span_name(&self) -> &'static str; + + fn run( + self: Arc, + session: Arc, + ctx: Arc, + input: Vec, + cancellation_token: CancellationToken, + ) -> BoxFuture<'static, Option>; + + fn abort<'a>( + &'a self, + session: Arc, + ctx: Arc, + ) -> BoxFuture<'a, ()>; +} + +impl AnySessionTask for T +where + T: SessionTask, +{ + fn kind(&self) -> TaskKind { + SessionTask::kind(self) + } + + fn span_name(&self) -> &'static str { + SessionTask::span_name(self) + } + + fn run( + self: Arc, + session: Arc, + ctx: Arc, + input: Vec, + cancellation_token: CancellationToken, + ) -> BoxFuture<'static, Option> { + Box::pin(SessionTask::run( + self, + session, + ctx, + input, + cancellation_token, + )) + } + + fn abort<'a>( + &'a self, + session: Arc, + ctx: Arc, + ) -> BoxFuture<'a, ()> { + Box::pin(SessionTask::abort(self, session, ctx)) } } @@ -179,7 +241,7 @@ impl Session { input: Vec, task: T, ) { - let task: Arc = Arc::new(task); + let task: Arc = Arc::new(task); let task_kind = task.kind(); let span_name = task.span_name(); let started_at = Instant::now(); diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs index 7a274d534fb..f2a29ee7ab6 100644 --- a/codex-rs/core/src/tasks/regular.rs +++ b/codex-rs/core/src/tasks/regular.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use async_trait::async_trait; use tokio_util::sync::CancellationToken; use crate::codex::TurnContext; @@ -25,7 +24,6 @@ impl RegularTask { } } -#[async_trait] impl SessionTask for RegularTask { fn kind(&self) -> TaskKind { TaskKind::Regular diff --git a/codex-rs/core/src/tasks/review.rs b/codex-rs/core/src/tasks/review.rs index e0c6033483a..a1cc071108c 100644 --- a/codex-rs/core/src/tasks/review.rs +++ b/codex-rs/core/src/tasks/review.rs @@ -1,7 +1,6 @@ use std::borrow::Cow; use std::sync::Arc; -use async_trait::async_trait; use codex_protocol::config_types::WebSearchMode; use codex_protocol::items::TurnItem; use codex_protocol::models::ContentItem; @@ -48,7 +47,6 @@ impl ReviewTask { } } -#[async_trait] impl SessionTask for ReviewTask { fn kind(&self) -> TaskKind { TaskKind::Review @@ -65,7 +63,7 @@ impl SessionTask for ReviewTask { input: Vec, cancellation_token: CancellationToken, ) -> Option { - let _ = session.session.services.session_telemetry.counter( + session.session.services.session_telemetry.counter( "codex.task.review", /*inc*/ 1, &[], diff --git a/codex-rs/core/src/tasks/undo.rs b/codex-rs/core/src/tasks/undo.rs index 48cdf11aabe..dd655300924 100644 --- a/codex-rs/core/src/tasks/undo.rs +++ b/codex-rs/core/src/tasks/undo.rs @@ -4,7 +4,6 @@ use crate::codex::TurnContext; use crate::state::TaskKind; use crate::tasks::SessionTask; use crate::tasks::SessionTaskContext; -use async_trait::async_trait; use codex_git_utils::RestoreGhostCommitOptions; use codex_git_utils::restore_ghost_commit_with_options; use codex_protocol::models::ResponseItem; @@ -25,7 +24,6 @@ impl UndoTask { } } -#[async_trait] impl SessionTask for UndoTask { fn kind(&self) -> TaskKind { TaskKind::Regular @@ -42,11 +40,11 @@ impl SessionTask for UndoTask { _input: Vec, cancellation_token: CancellationToken, ) -> Option { - let _ = session.session.services.session_telemetry.counter( - "codex.task.undo", - /*inc*/ 1, - &[], - ); + session + .session + .services + .session_telemetry + .counter("codex.task.undo", /*inc*/ 1, &[]); let sess = session.clone_session(); sess.send_event( ctx.as_ref(), diff --git a/codex-rs/core/src/tasks/user_shell.rs b/codex-rs/core/src/tasks/user_shell.rs index 4d4ff383ac5..449db837fd8 100644 --- a/codex-rs/core/src/tasks/user_shell.rs +++ b/codex-rs/core/src/tasks/user_shell.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use std::time::Duration; -use async_trait::async_trait; use codex_async_utils::CancelErr; use codex_async_utils::OrCancelExt; use codex_protocol::user_input::UserInput; @@ -62,7 +61,6 @@ impl UserShellCommandTask { } } -#[async_trait] impl SessionTask for UserShellCommandTask { fn kind(&self) -> TaskKind { TaskKind::Regular diff --git a/codex-rs/core/src/tools/handlers/multi_agents_tests.rs b/codex-rs/core/src/tools/handlers/multi_agents_tests.rs index ee81b961e31..fa231e66b8c 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_tests.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_tests.rs @@ -115,7 +115,6 @@ fn history_contains_inter_agent_communication( #[derive(Clone, Copy)] struct NeverEndingTask; -#[async_trait::async_trait] impl SessionTask for NeverEndingTask { fn kind(&self) -> TaskKind { TaskKind::Regular