Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 6 additions & 54 deletions codex-rs/core/src/goals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ use codex_utils_template::Template;
use futures::future::BoxFuture;
use std::sync::Arc;
use std::sync::LazyLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::time::Instant;
use tokio::sync::Mutex;
Expand Down Expand Up @@ -90,7 +88,6 @@ pub(crate) enum GoalRuntimeEvent<'a> {
TurnFinished {
turn_context: &'a TurnContext,
turn_completed: bool,
tool_calls: u64,
},
MaybeContinueIfIdle,
TaskAborted {
Expand All @@ -112,7 +109,6 @@ pub(crate) struct GoalRuntimeState {
accounting: Mutex<GoalAccountingSnapshot>,
continuation_turn_id: Mutex<Option<String>>,
pub(crate) continuation_lock: Semaphore,
pub(crate) continuation_suppressed: AtomicBool,
}

struct GoalContinuationCandidate {
Expand All @@ -129,7 +125,6 @@ impl GoalRuntimeState {
accounting: Mutex::new(GoalAccountingSnapshot::new()),
continuation_turn_id: Mutex::new(None),
continuation_lock: Semaphore::new(/*permits*/ 1),
continuation_suppressed: AtomicBool::new(false),
}
}
}
Expand Down Expand Up @@ -277,8 +272,8 @@ impl Session {
/// suppresses that steering, external mutations account best-effort before
/// changing state, interrupts pause active goals, resumes reactivate paused
/// goals, explicit maybe-continue events start idle goal continuation turns,
/// and no-tool continuation turns suppress the next automatic continuation
/// until user/tool/external activity resets it.
/// and continuation turns with no counted autonomous activity suppress the
/// next automatic continuation until user/tool/external activity resets it.
pub(crate) fn goal_runtime_apply<'a>(
self: &'a Arc<Self>,
event: GoalRuntimeEvent<'a>,
Expand All @@ -296,25 +291,22 @@ impl Session {
turn_context,
tool_name,
} => Box::pin(async move {
self.reset_thread_goal_continuation_suppression();
if tool_name != codex_tools::UPDATE_GOAL_TOOL_NAME {
self.account_thread_goal_progress(turn_context, BudgetLimitSteering::Allowed)
.await?;
}
Ok(())
}),
GoalRuntimeEvent::ToolCompletedGoal { turn_context } => Box::pin(async move {
self.reset_thread_goal_continuation_suppression();
self.account_thread_goal_progress(turn_context, BudgetLimitSteering::Suppressed)
.await?;
Ok(())
}),
GoalRuntimeEvent::TurnFinished {
turn_context,
turn_completed,
tool_calls,
} => Box::pin(async move {
self.finish_thread_goal_turn(turn_context, turn_completed, tool_calls)
self.finish_thread_goal_turn(turn_context, turn_completed)
.await;
Ok(())
}),
Expand All @@ -331,7 +323,6 @@ impl Session {
Ok(())
}),
GoalRuntimeEvent::ExternalMutationStarting => Box::pin(async move {
self.reset_thread_goal_continuation_suppression();
if let Err(err) = self.account_thread_goal_before_external_mutation().await {
tracing::warn!(
"failed to account thread goal progress before external mutation: {err}"
Expand Down Expand Up @@ -463,7 +454,6 @@ impl Session {
let goal_status = goal.status;
let goal_id = goal.goal_id.clone();
let goal = protocol_goal_from_state(goal);
self.reset_thread_goal_continuation_suppression();
*self.goal_runtime.budget_limit_reported_goal_id.lock().await = None;
let newly_active_goal = goal_status == codex_state::ThreadGoalStatus::Active
&& (replacing_goal
Expand Down Expand Up @@ -532,7 +522,6 @@ impl Session {

let goal_id = goal.goal_id.clone();
let goal = protocol_goal_from_state(goal);
self.reset_thread_goal_continuation_suppression();
*self.goal_runtime.budget_limit_reported_goal_id.lock().await = None;

let current_token_usage = self.total_token_usage().await.unwrap_or_default();
Expand Down Expand Up @@ -561,7 +550,6 @@ impl Session {
) {
match status {
codex_state::ThreadGoalStatus::Active => {
self.reset_thread_goal_continuation_suppression();
match self.state_db_for_thread_goals().await {
Ok(Some(state_db)) => {
match state_db.get_thread_goal(self.conversation_id).await {
Expand Down Expand Up @@ -608,7 +596,6 @@ impl Session {
}

async fn clear_stopped_thread_goal_runtime_state(&self) {
self.reset_thread_goal_continuation_suppression();
*self.goal_runtime.budget_limit_reported_goal_id.lock().await = None;
let mut accounting = self.goal_runtime.accounting.lock().await;
if let Some(turn) = accounting.turn.as_mut() {
Expand Down Expand Up @@ -663,16 +650,6 @@ impl Session {
turn_context: &TurnContext,
token_usage: TokenUsage,
) {
if self
.goal_runtime
.continuation_turn_id
.lock()
.await
.as_ref()
.is_none_or(|turn_id| turn_id != &turn_context.sub_id)
{
self.reset_thread_goal_continuation_suppression();
}
self.goal_runtime.accounting.lock().await.turn = Some(GoalTurnAccountingSnapshot::new(
turn_context.sub_id.clone(),
token_usage,
Expand Down Expand Up @@ -723,12 +700,6 @@ impl Session {
}
}

/// Clears the continuation-suppression flag.
///
/// Writes `false` into `goal_runtime.continuation_suppressed` using
/// sequentially consistent ordering.
fn reset_thread_goal_continuation_suppression(&self) {
let suppressed_flag = &self.goal_runtime.continuation_suppressed;
suppressed_flag.store(false, Ordering::SeqCst);
}

/// Records `turn_id` as the current goal-continuation turn.
///
/// Acquires the `continuation_turn_id` lock and overwrites whatever id was
/// stored there before.
async fn mark_thread_goal_continuation_turn_started(&self, turn_id: String) {
let mut current_turn_id = self.goal_runtime.continuation_turn_id.lock().await;
*current_turn_id = Some(turn_id);
}
Expand Down Expand Up @@ -757,7 +728,6 @@ impl Session {
self: &Arc<Self>,
turn_context: &TurnContext,
turn_completed: bool,
turn_tool_calls: u64,
) {
if turn_completed
&& let Err(err) = self
Expand All @@ -767,15 +737,8 @@ impl Session {
tracing::warn!("failed to account thread goal progress at turn end: {err}");
}

if self
.take_thread_goal_continuation_turn(&turn_context.sub_id)
.await
&& turn_tool_calls == 0
{
self.goal_runtime
.continuation_suppressed
.store(true, Ordering::SeqCst);
}
self.take_thread_goal_continuation_turn(&turn_context.sub_id)
.await;
if turn_completed {
let mut accounting = self.goal_runtime.accounting.lock().await;
if accounting
Expand Down Expand Up @@ -1126,7 +1089,6 @@ impl Session {
};
let goal_id = goal.goal_id.clone();
let goal = protocol_goal_from_state(goal);
self.reset_thread_goal_continuation_suppression();
*self.goal_runtime.budget_limit_reported_goal_id.lock().await = None;
let active_turn_id = self
.active_turn_context()
Expand Down Expand Up @@ -1255,16 +1217,6 @@ impl Session {
);
return None;
}
if self
.goal_runtime
.continuation_suppressed
.load(Ordering::SeqCst)
{
tracing::debug!(
"skipping active goal continuation because the last continuation made no tool calls"
);
return None;
}
let state_db = match self.state_db_for_thread_goals().await {
Ok(Some(state_db)) => state_db,
Ok(None) => {
Expand Down Expand Up @@ -1578,7 +1530,7 @@ mod tests {
assert!(prompt.contains("<untrusted_objective>\nfinish the stack\n</untrusted_objective>"));
assert!(prompt.contains("Token budget: 10000"));
assert!(prompt.contains("call update_goal with status \"complete\""));
assert!(prompt.contains(
assert!(!prompt.contains(
"explain the blocker or next required input to the user and wait for new input"
));
assert!(!prompt.contains("budgetLimited"));
Expand Down
136 changes: 131 additions & 5 deletions codex-rs/core/src/session/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ use codex_protocol::protocol::TurnCompleteEvent;
use codex_protocol::protocol::TurnStartedEvent;
use codex_protocol::protocol::UserMessageEvent;
use codex_protocol::protocol::W3cTraceContext;
use codex_protocol::request_user_input::RequestUserInputAnswer;
use codex_protocol::request_user_input::RequestUserInputResponse;
use core_test_support::PathBufExt;
use core_test_support::PathExt;
use core_test_support::context_snapshot;
Expand All @@ -136,6 +138,7 @@ use core_test_support::test_codex::test_codex;
use core_test_support::test_path_buf;
use core_test_support::tracing::install_test_tracing;
use core_test_support::wait_for_event;
use core_test_support::wait_for_event_match;
use opentelemetry::trace::TraceContextExt;
use opentelemetry::trace::TraceId;
use opentelemetry_sdk::metrics::InMemoryMetricExporter;
Expand Down Expand Up @@ -6941,7 +6944,7 @@ async fn interrupt_accounts_active_goal_before_pausing() -> anyhow::Result<()> {
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn active_goal_continuation_runs_to_completion_after_turn() -> anyhow::Result<()> {
async fn active_goal_continuation_runs_again_after_no_tool_turn() -> anyhow::Result<()> {
let server = start_mock_server().await;
let mut builder = test_codex().with_config(|config| {
config
Expand All @@ -6967,18 +6970,107 @@ async fn active_goal_continuation_runs_to_completion_after_turn() -> anyhow::Res
ev_completed("resp-2"),
]),
sse(vec![
ev_response_created("resp-3"),
ev_assistant_message("msg-2", "I am still working on the benchmark note."),
ev_completed("resp-3"),
]),
sse(vec![
ev_response_created("resp-4"),
ev_function_call(
"call-complete-goal",
"update_goal",
r#"{"status":"complete"}"#,
),
ev_completed("resp-4"),
]),
sse(vec![
ev_assistant_message("msg-3", "Goal complete."),
ev_completed("resp-5"),
]),
],
)
.await;

test.codex
.submit(Op::UserInput {
environments: None,
items: vec![UserInput::Text {
text: "write a benchmark note".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
responsesapi_client_metadata: None,
})
.await?;

let mut completed_turns = 0;
tokio::time::timeout(std::time::Duration::from_secs(8), async {
loop {
let event = test.codex.next_event().await?;
if matches!(event.msg, EventMsg::TurnComplete(_)) {
completed_turns += 1;
if completed_turns == 3 {
return anyhow::Ok(());
}
}
}
})
.await??;

Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pending_request_user_input_does_not_spawn_extra_goal_continuation() -> anyhow::Result<()> {
let server = start_mock_server().await;
let mut builder = test_codex().with_config(|config| {
config
.features
.enable(Feature::Goals)
.expect("goal mode should be enableable in tests");
config
.features
.enable(Feature::DefaultModeRequestUserInput)
.expect("default-mode request_user_input should be enableable in tests");
});
let test = builder.build(&server).await?;
let responses = mount_sse_sequence(
&server,
vec![
sse(vec![
ev_response_created("resp-1"),
ev_function_call(
"call-create-goal",
"create_goal",
r#"{"objective":"write a benchmark note"}"#,
),
ev_completed("resp-1"),
]),
sse(vec![
ev_assistant_message("msg-1", "Draft ready."),
ev_completed("resp-2"),
]),
sse(vec![
ev_response_created("resp-3"),
ev_function_call(
"call-ask-user",
"request_user_input",
r#"{"questions":[{"header":"Choice","id":"next_step","question":"Pick one","options":[{"label":"Outline","description":"Start with an outline."},{"label":"Draft","description":"Write a full draft."}]}]}"#,
),
ev_completed("resp-3"),
]),
sse(vec![
ev_assistant_message("msg-2", "Goal complete."),
ev_response_created("resp-4"),
ev_function_call(
"call-complete-goal",
"update_goal",
r#"{"status":"complete"}"#,
),
ev_completed("resp-4"),
]),
sse(vec![
ev_assistant_message("msg-2", "Goal complete."),
ev_completed("resp-5"),
]),
],
)
.await;
Expand All @@ -6995,20 +7087,54 @@ async fn active_goal_continuation_runs_to_completion_after_turn() -> anyhow::Res
})
.await?;

let request_user_input_event = wait_for_event_match(&test.codex, |event| match event {
EventMsg::RequestUserInput(event) => Some(event.clone()),
_ => None,
})
.await;
assert_eq!(3, responses.requests().len());
assert!(
timeout(Duration::from_millis(200), test.codex.next_event())
.await
.is_err(),
"waiting for request_user_input should keep the turn open without emitting more events"
);
assert_eq!(
3,
responses.requests().len(),
"waiting for request_user_input should not start another continuation request"
);

test.codex
.submit(Op::UserInputAnswer {
id: request_user_input_event.turn_id,
response: RequestUserInputResponse {
answers: std::collections::HashMap::from([(
"next_step".to_string(),
RequestUserInputAnswer {
answers: vec!["Outline".to_string()],
},
)]),
},
})
.await?;

let mut completed_turns = 0;
tokio::time::timeout(std::time::Duration::from_secs(8), async {
timeout(Duration::from_secs(8), async {
loop {
let event = test.codex.next_event().await?;
if matches!(event.msg, EventMsg::TurnComplete(_)) {
completed_turns += 1;
if completed_turns == 2 {
if completed_turns == 1 {
return anyhow::Ok(());
}
}
}
})
.await??;

assert_eq!(5, responses.requests().len());

Ok(())
}

Expand Down
Loading
Loading