Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,14 @@ impl ThreadGoalRequestProcessor {

let (goal, previous_status) = (if let Some(objective) = objective {
let existing_goal = state_db
.thread_goals()
.get_thread_goal(thread_id)
.await
.map_err(|err| invalid_request(err.to_string()))?;
if let Some(goal) = existing_goal.as_ref() {
let previous_status = ExternalGoalPreviousStatus::from(goal);
state_db
.thread_goals()
.update_thread_goal(
thread_id,
codex_state::ThreadGoalUpdate {
Expand All @@ -177,6 +179,7 @@ impl ThreadGoalRequestProcessor {
} else {
let previous_status = ExternalGoalPreviousStatus::NewGoal;
state_db
.thread_goals()
.replace_thread_goal(
thread_id,
objective,
Expand All @@ -188,6 +191,7 @@ impl ThreadGoalRequestProcessor {
}
} else {
let existing_goal = state_db
.thread_goals()
.get_thread_goal(thread_id)
.await
.map_err(|err| invalid_request(err.to_string()))?;
Expand All @@ -198,6 +202,7 @@ impl ThreadGoalRequestProcessor {
};
let previous_status = ExternalGoalPreviousStatus::from(&existing_goal);
state_db
.thread_goals()
.update_thread_goal(
thread_id,
codex_state::ThreadGoalUpdate {
Expand Down Expand Up @@ -246,6 +251,7 @@ impl ThreadGoalRequestProcessor {
let thread_id = parse_thread_id_for_request(params.thread_id.as_str())?;
let state_db = self.state_db_for_materialized_thread(thread_id).await?;
let goal = state_db
.thread_goals()
.get_thread_goal(thread_id)
.await
.map_err(|err| internal_error(format!("failed to read thread goal: {err}")))?
Expand Down Expand Up @@ -303,6 +309,7 @@ impl ThreadGoalRequestProcessor {
thread_state.listener_command_tx()
};
let cleared = state_db
.thread_goals()
.delete_thread_goal(thread_id)
.await
.map_err(|err| internal_error(format!("failed to clear thread goal: {err}")))?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ pub(super) async fn send_thread_goal_snapshot_notification(
thread_id: ThreadId,
state_db: &StateDbHandle,
) {
match state_db.get_thread_goal(thread_id).await {
match state_db.thread_goals().get_thread_goal(thread_id).await {
Ok(Some(goal)) => {
outgoing
.send_server_notification(ServerNotification::ThreadGoalUpdated(
Expand Down
3 changes: 3 additions & 0 deletions codex-rs/app-server/tests/suite/v2/thread_resume.rs
Original file line number Diff line number Diff line change
Expand Up @@ -943,10 +943,12 @@ async fn thread_goal_set_edits_objective_without_resetting_usage() -> Result<()>
StateRuntime::init(codex_home.path().to_path_buf(), "mock_provider".into()).await?;
let thread_id = ThreadId::from_string(&thread_id)?;
let persisted_goal = state_db
.thread_goals()
.get_thread_goal(thread_id)
.await?
.expect("goal should exist");
state_db
.thread_goals()
.account_thread_goal_usage(
thread_id,
/*time_delta_seconds*/ 12,
Expand Down Expand Up @@ -974,6 +976,7 @@ async fn thread_goal_set_edits_objective_without_resetting_usage() -> Result<()>
.await??;
let edit: ThreadGoalSetResponse = to_response(edit_resp)?;
let updated_goal = state_db
.thread_goals()
.get_thread_goal(thread_id)
.await?
.expect("goal should still exist");
Expand Down
47 changes: 40 additions & 7 deletions codex-rs/core/src/goals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ impl Session {

let state_db = self.require_state_db_for_thread_goals().await?;
state_db
.thread_goals()
.get_thread_goal(self.conversation_id)
.await
.map(|goal| goal.map(protocol_goal_from_state))
Expand Down Expand Up @@ -459,10 +460,14 @@ impl Session {
let mut replacing_goal = false;
let previous_status;
let goal = if let Some(objective) = objective.as_deref() {
let existing_goal = state_db.get_thread_goal(self.conversation_id).await?;
let existing_goal = state_db
.thread_goals()
.get_thread_goal(self.conversation_id)
.await?;
previous_status = existing_goal.as_ref().map(|goal| goal.status);
if let Some(existing_goal) = existing_goal.as_ref() {
state_db
.thread_goals()
.update_thread_goal(
self.conversation_id,
codex_state::ThreadGoalUpdate {
Expand All @@ -482,6 +487,7 @@ impl Session {
} else {
replacing_goal = true;
state_db
.thread_goals()
.replace_thread_goal(
self.conversation_id,
objective,
Expand All @@ -493,11 +499,15 @@ impl Session {
.await?
}
} else {
let existing_goal = state_db.get_thread_goal(self.conversation_id).await?;
let existing_goal = state_db
.thread_goals()
.get_thread_goal(self.conversation_id)
.await?;
previous_status = existing_goal.as_ref().map(|goal| goal.status);
let expected_goal_id = existing_goal.map(|goal| goal.goal_id);
let status = status.map(state_goal_status_from_protocol);
state_db
.thread_goals()
.update_thread_goal(
self.conversation_id,
codex_state::ThreadGoalUpdate {
Expand Down Expand Up @@ -581,6 +591,7 @@ impl Session {
)
.await?;
let goal = state_db
.thread_goals()
.insert_thread_goal(
self.conversation_id,
objective,
Expand Down Expand Up @@ -760,7 +771,10 @@ impl Session {
state_db: &StateDbHandle,
expected_goal_id: Option<&str>,
) -> anyhow::Result<Option<codex_state::ThreadGoalStatus>> {
let goal = state_db.get_thread_goal(self.conversation_id).await?;
let goal = state_db
.thread_goals()
.get_thread_goal(self.conversation_id)
.await?;
Ok(goal.and_then(|goal| {
expected_goal_id
.is_none_or(|expected_goal_id| goal.goal_id == expected_goal_id)
Expand Down Expand Up @@ -801,7 +815,11 @@ impl Session {
return;
}
};
match state_db.get_thread_goal(self.conversation_id).await {
match state_db
.thread_goals()
.get_thread_goal(self.conversation_id)
.await
{
Ok(Some(goal))
if matches!(
goal.status,
Expand Down Expand Up @@ -963,6 +981,7 @@ impl Session {
.current_goal_status_for_metrics(&state_db, expected_goal_id.as_deref())
.await?;
let outcome = state_db
.thread_goals()
.account_thread_goal_usage(
self.conversation_id,
time_delta_seconds,
Expand Down Expand Up @@ -1085,6 +1104,7 @@ impl Session {
.await?;

match state_db
.thread_goals()
.account_thread_goal_usage(
self.conversation_id,
time_delta_seconds,
Expand Down Expand Up @@ -1147,6 +1167,7 @@ impl Session {
)
.await?;
let Some(goal) = state_db
.thread_goals()
.pause_active_thread_goal(self.conversation_id)
.await?
else {
Expand Down Expand Up @@ -1192,7 +1213,11 @@ impl Session {
let Some(state_db) = self.state_db_for_thread_goals().await? else {
return Ok(());
};
let Some(goal) = state_db.get_thread_goal(self.conversation_id).await? else {
let Some(goal) = state_db
.thread_goals()
.get_thread_goal(self.conversation_id)
.await?
else {
self.clear_stopped_thread_goal_runtime_state().await;
return Ok(());
};
Expand Down Expand Up @@ -1237,7 +1262,11 @@ impl Session {
Arc::clone(&active_turn.turn_state)
};
let goal_is_current = match self.state_db_for_thread_goals().await {
Ok(Some(state_db)) => match state_db.get_thread_goal(self.conversation_id).await {
Ok(Some(state_db)) => match state_db
.thread_goals()
.get_thread_goal(self.conversation_id)
.await
{
Ok(Some(goal))
if goal.goal_id == candidate.goal_id
&& goal.status == codex_state::ThreadGoalStatus::Active =>
Expand Down Expand Up @@ -1333,7 +1362,11 @@ impl Session {
return None;
}
};
let goal = match state_db.get_thread_goal(self.conversation_id).await {
let goal = match state_db
.thread_goals()
.get_thread_goal(self.conversation_id)
.await
{
Ok(Some(goal)) => goal,
Ok(None) => {
tracing::debug!("skipping active goal continuation because no goal is set");
Expand Down
10 changes: 10 additions & 0 deletions codex-rs/core/src/session/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8592,6 +8592,7 @@ async fn budget_limited_accounting_steers_active_turn_without_aborting() -> anyh

let state_db = goal_test_state_db(sess.as_ref()).await?;
let goal = state_db
.thread_goals()
.get_thread_goal(sess.conversation_id)
.await?
.expect("goal should remain persisted after accounting");
Expand All @@ -8615,6 +8616,7 @@ async fn budget_limited_accounting_steers_active_turn_without_aborting() -> anyh
.await?;

let goal = state_db
.thread_goals()
.get_thread_goal(sess.conversation_id)
.await?
.expect("goal should remain persisted after follow-up accounting");
Expand Down Expand Up @@ -8654,6 +8656,7 @@ async fn external_goal_mutation_accounts_active_turn_before_status_change() -> a

let state_db = goal_test_state_db(sess.as_ref()).await?;
let goal = state_db
.thread_goals()
.get_thread_goal(sess.conversation_id)
.await?
.expect("goal should remain persisted");
Expand All @@ -8662,6 +8665,7 @@ async fn external_goal_mutation_accounts_active_turn_before_status_change() -> a
let previous_goal = goal.clone();
let goal_id = goal.goal_id.clone();
let updated_goal = state_db
.thread_goals()
.update_thread_goal(
sess.conversation_id,
codex_state::ThreadGoalUpdate {
Expand All @@ -8683,6 +8687,7 @@ async fn external_goal_mutation_accounts_active_turn_before_status_change() -> a

assert!(sess.active_turn.lock().await.is_some());
let goal = state_db
.thread_goals()
.get_thread_goal(sess.conversation_id)
.await?
.expect("goal should remain persisted");
Expand All @@ -8709,6 +8714,7 @@ async fn external_objective_change_steers_active_turn() -> anyhow::Result<()> {

let state_db = goal_test_state_db(sess.as_ref()).await?;
let old_goal = state_db
.thread_goals()
.replace_thread_goal(
sess.conversation_id,
"Keep improving the benchmark",
Expand All @@ -8717,6 +8723,7 @@ async fn external_objective_change_steers_active_turn() -> anyhow::Result<()> {
)
.await?;
let new_goal = state_db
.thread_goals()
.replace_thread_goal(
sess.conversation_id,
"Write a concise benchmark summary",
Expand Down Expand Up @@ -8774,6 +8781,7 @@ async fn external_active_goal_set_marks_current_turn_for_accounting() -> anyhow:

let state_db = goal_test_state_db(sess.as_ref()).await?;
let goal = state_db
.thread_goals()
.replace_thread_goal(
sess.conversation_id,
"Keep improving the benchmark",
Expand Down Expand Up @@ -8807,6 +8815,7 @@ async fn external_active_goal_set_marks_current_turn_for_accounting() -> anyhow:
.await?;

let goal = state_db
.thread_goals()
.get_thread_goal(sess.conversation_id)
.await?
.expect("goal should remain persisted");
Expand Down Expand Up @@ -8905,6 +8914,7 @@ async fn completed_goal_accounts_current_turn_tokens_before_tool_response() -> a
)
.await?;
let persisted_goal = state_db
.thread_goals()
.get_thread_goal(test.session_configured.thread_id)
.await?
.expect("goal should be persisted");
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/core/src/thread_manager_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,7 @@ async fn resumed_thread_keeps_paused_goal_paused() -> anyhow::Result<()> {
.state_db()
.expect("source thread should have a state db");
state_db
.thread_goals()
.replace_thread_goal(
source.thread_id,
"Keep working until the task is done",
Expand All @@ -1581,6 +1582,7 @@ async fn resumed_thread_keeps_paused_goal_paused() -> anyhow::Result<()> {
.await
.expect("resume source thread");
let goal = state_db
.thread_goals()
.get_thread_goal(resumed.thread_id)
.await?
.expect("goal should still exist after resume");
Expand Down
1 change: 1 addition & 0 deletions codex-rs/state/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub use model::ThreadGoalStatus;
pub use model::ThreadMetadata;
pub use model::ThreadMetadataBuilder;
pub use model::ThreadsPage;
pub use runtime::GoalStore;
pub use runtime::RemoteControlEnrollmentRecord;
pub use runtime::ThreadFilterOptions;
pub use runtime::ThreadGoalAccountingMode;
Expand Down
8 changes: 7 additions & 1 deletion codex-rs/state/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use crate::apply_rollout_item;
use crate::migrations::runtime_logs_migrator;
use crate::migrations::runtime_state_migrator;
use crate::model::AgentJobRow;
use crate::model::ThreadGoalRow;
use crate::model::ThreadRow;
use crate::model::anchor_from_item;
use crate::model::datetime_to_epoch_millis;
Expand Down Expand Up @@ -65,6 +64,7 @@ mod remote_control;
mod test_support;
mod threads;

pub use goals::GoalStore;
pub use goals::ThreadGoalAccountingMode;
pub use goals::ThreadGoalAccountingOutcome;
pub use goals::ThreadGoalUpdate;
Expand All @@ -86,6 +86,7 @@ pub struct StateRuntime {
default_provider: String,
pool: Arc<sqlx::SqlitePool>,
logs_pool: Arc<sqlx::SqlitePool>,
thread_goals: GoalStore,
thread_updated_at_millis: Arc<AtomicI64>,
}

Expand Down Expand Up @@ -164,6 +165,7 @@ impl StateRuntime {
let thread_updated_at_millis = thread_updated_at_millis_result?;
let thread_updated_at_millis = thread_updated_at_millis.unwrap_or(0);
let runtime = Arc::new(Self {
thread_goals: GoalStore::new(Arc::clone(&pool)),
pool,
logs_pool,
codex_home,
Expand All @@ -183,6 +185,10 @@ impl StateRuntime {
pub fn codex_home(&self) -> &Path {
self.codex_home.as_path()
}

pub fn thread_goals(&self) -> &GoalStore {
&self.thread_goals
}
}

fn base_sqlite_options(path: &Path) -> SqliteConnectOptions {
Expand Down
Loading
Loading