Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions codex-rs/core/src/codex_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3119,7 +3119,6 @@ async fn spawn_task_turn_span_inherits_dispatch_trace_context() {
captured_trace: Arc<std::sync::Mutex<Option<W3cTraceContext>>>,
}

#[async_trait::async_trait]
impl SessionTask for TraceCaptureTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
Expand Down Expand Up @@ -4375,7 +4374,6 @@ struct NeverEndingTask {
listen_to_cancellation_token: bool,
}

#[async_trait::async_trait]
impl SessionTask for NeverEndingTask {
fn kind(&self) -> TaskKind {
self.kind
Expand Down
4 changes: 2 additions & 2 deletions codex-rs/core/src/state/turn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use rmcp::model::RequestId;
use tokio::sync::oneshot;

use crate::codex::TurnContext;
use crate::tasks::SessionTask;
use crate::tasks::AnySessionTask;
use codex_protocol::models::PermissionProfile;
use codex_protocol::protocol::ReviewDecision;
use codex_protocol::protocol::TokenUsage;
Expand Down Expand Up @@ -69,7 +69,7 @@ pub(crate) enum TaskKind {
pub(crate) struct RunningTask {
pub(crate) done: Arc<Notify>,
pub(crate) kind: TaskKind,
pub(crate) task: Arc<dyn SessionTask>,
pub(crate) task: Arc<dyn AnySessionTask>,
pub(crate) cancellation_token: CancellationToken,
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
pub(crate) turn_context: Arc<TurnContext>,
Expand Down
6 changes: 2 additions & 4 deletions codex-rs/core/src/tasks/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@ use super::SessionTask;
use super::SessionTaskContext;
use crate::codex::TurnContext;
use crate::state::TaskKind;
use async_trait::async_trait;
use codex_protocol::user_input::UserInput;
use tokio_util::sync::CancellationToken;

#[derive(Clone, Copy, Default)]
pub(crate) struct CompactTask;

#[async_trait]
impl SessionTask for CompactTask {
fn kind(&self) -> TaskKind {
TaskKind::Compact
Expand All @@ -30,14 +28,14 @@ impl SessionTask for CompactTask {
) -> Option<String> {
let session = session.clone_session();
let _ = if crate::compact::should_use_remote_compact_task(&ctx.provider) {
let _ = session.services.session_telemetry.counter(
session.services.session_telemetry.counter(
"codex.task.compact",
/*inc*/ 1,
&[("type", "remote")],
);
crate::compact_remote::run_remote_compact_task(session.clone(), ctx).await
} else {
let _ = session.services.session_telemetry.counter(
session.services.session_telemetry.counter(
"codex.task.compact",
/*inc*/ 1,
&[("type", "local")],
Expand Down
2 changes: 0 additions & 2 deletions codex-rs/core/src/tasks/ghost_snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::codex::TurnContext;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use async_trait::async_trait;
use codex_git_utils::CreateGhostCommitOptions;
use codex_git_utils::GhostSnapshotReport;
use codex_git_utils::GitToolingError;
Expand All @@ -26,7 +25,6 @@ pub(crate) struct GhostSnapshotTask {

const SNAPSHOT_WARNING_THRESHOLD: Duration = Duration::from_secs(240);

#[async_trait]
impl SessionTask for GhostSnapshotTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
Expand Down
76 changes: 69 additions & 7 deletions codex-rs/core/src/tasks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use async_trait::async_trait;
use futures::future::BoxFuture;
use tokio::select;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
Expand Down Expand Up @@ -126,7 +126,6 @@ impl SessionTaskContext {
/// intentionally small: implementers identify themselves via
/// [`SessionTask::kind`], perform their work in [`SessionTask::run`], and may
/// release resources in [`SessionTask::abort`].
#[async_trait]
pub(crate) trait SessionTask: Send + Sync + 'static {
/// Describes the type of work the task performs so the session can
/// surface it in telemetry and UI.
Expand All @@ -143,21 +142,84 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
/// abort; implementers should watch for it and terminate quickly once it
/// fires. Returning [`Some`] yields a final message that
/// [`Session::on_task_finished`] will emit to the client.
async fn run(
fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String>;
) -> impl std::future::Future<Output = Option<String>> + Send;

/// Gives the task a chance to perform cleanup after an abort.
///
/// The default implementation is a no-op; override this if additional
/// teardown or notifications are required once
/// [`Session::abort_all_tasks`] cancels the task.
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
let _ = (session, ctx);
fn abort(
&self,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
) -> impl std::future::Future<Output = ()> + Send {
async move {
let _ = (session, ctx);
}
}
}

pub(crate) trait AnySessionTask: Send + Sync + 'static {
fn kind(&self) -> TaskKind;

fn span_name(&self) -> &'static str;

fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> BoxFuture<'static, Option<String>>;

fn abort<'a>(
&'a self,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
) -> BoxFuture<'a, ()>;
}

impl<T> AnySessionTask for T
where
T: SessionTask,
{
fn kind(&self) -> TaskKind {
SessionTask::kind(self)
}

fn span_name(&self) -> &'static str {
SessionTask::span_name(self)
}

fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> BoxFuture<'static, Option<String>> {
Box::pin(SessionTask::run(
self,
session,
ctx,
input,
cancellation_token,
))
}

fn abort<'a>(
&'a self,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
) -> BoxFuture<'a, ()> {
Box::pin(SessionTask::abort(self, session, ctx))
}
}

Expand All @@ -179,7 +241,7 @@ impl Session {
input: Vec<UserInput>,
task: T,
) {
let task: Arc<dyn SessionTask> = Arc::new(task);
let task: Arc<dyn AnySessionTask> = Arc::new(task);
let task_kind = task.kind();
let span_name = task.span_name();
let started_at = Instant::now();
Expand Down
2 changes: 0 additions & 2 deletions codex-rs/core/src/tasks/regular.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::sync::Arc;

use async_trait::async_trait;
use tokio_util::sync::CancellationToken;

use crate::codex::TurnContext;
Expand All @@ -25,7 +24,6 @@ impl RegularTask {
}
}

#[async_trait]
impl SessionTask for RegularTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
Expand Down
4 changes: 1 addition & 3 deletions codex-rs/core/src/tasks/review.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::borrow::Cow;
use std::sync::Arc;

use async_trait::async_trait;
use codex_protocol::config_types::WebSearchMode;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ContentItem;
Expand Down Expand Up @@ -48,7 +47,6 @@ impl ReviewTask {
}
}

#[async_trait]
impl SessionTask for ReviewTask {
fn kind(&self) -> TaskKind {
TaskKind::Review
Expand All @@ -65,7 +63,7 @@ impl SessionTask for ReviewTask {
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
let _ = session.session.services.session_telemetry.counter(
session.session.services.session_telemetry.counter(
"codex.task.review",
/*inc*/ 1,
&[],
Expand Down
12 changes: 5 additions & 7 deletions codex-rs/core/src/tasks/undo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::codex::TurnContext;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use async_trait::async_trait;
use codex_git_utils::RestoreGhostCommitOptions;
use codex_git_utils::restore_ghost_commit_with_options;
use codex_protocol::models::ResponseItem;
Expand All @@ -25,7 +24,6 @@ impl UndoTask {
}
}

#[async_trait]
impl SessionTask for UndoTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
Expand All @@ -42,11 +40,11 @@ impl SessionTask for UndoTask {
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
let _ = session.session.services.session_telemetry.counter(
"codex.task.undo",
/*inc*/ 1,
&[],
);
session
.session
.services
.session_telemetry
.counter("codex.task.undo", /*inc*/ 1, &[]);
let sess = session.clone_session();
sess.send_event(
ctx.as_ref(),
Expand Down
2 changes: 0 additions & 2 deletions codex-rs/core/src/tasks/user_shell.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use codex_async_utils::CancelErr;
use codex_async_utils::OrCancelExt;
use codex_protocol::user_input::UserInput;
Expand Down Expand Up @@ -62,7 +61,6 @@ impl UserShellCommandTask {
}
}

#[async_trait]
impl SessionTask for UserShellCommandTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
Expand Down
1 change: 0 additions & 1 deletion codex-rs/core/src/tools/handlers/multi_agents_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ fn history_contains_inter_agent_communication(
#[derive(Clone, Copy)]
struct NeverEndingTask;

#[async_trait::async_trait]
impl SessionTask for NeverEndingTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
Expand Down
Loading