Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion codex-rs/core/src/state_db.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::config::Config;
use crate::features::Feature;
use crate::path_utils::normalize_for_path_comparison;
use crate::rollout::list::Cursor;
use crate::rollout::list::ThreadSortKey;
use crate::rollout::metadata;
Expand Down Expand Up @@ -156,6 +157,10 @@ fn cursor_to_anchor(cursor: Option<&Cursor>) -> Option<codex_state::Anchor> {
Some(codex_state::Anchor { ts, id })
}

/// Normalize a working directory for storage/lookup in the state DB,
/// falling back to the raw path whenever normalization fails so callers
/// always receive a usable value.
fn normalize_cwd_for_state_db(cwd: &Path) -> PathBuf {
    match normalize_for_path_comparison(cwd) {
        Ok(normalized) => normalized,
        Err(_) => cwd.to_path_buf(),
    }
}

/// List thread ids from SQLite for parity checks without rollout scanning.
#[allow(clippy::too_many_arguments)]
pub async fn list_thread_ids_db(
Expand Down Expand Up @@ -355,7 +360,11 @@ pub async fn get_last_n_thread_memories_for_cwd(
stage: &str,
) -> Option<Vec<codex_state::ThreadMemory>> {
let ctx = context?;
match ctx.get_last_n_thread_memories_for_cwd(cwd, n).await {
let normalized_cwd = normalize_cwd_for_state_db(cwd);
match ctx
.get_last_n_thread_memories_for_cwd(&normalized_cwd, n)
.await
{
Ok(memories) => Some(memories),
Err(err) => {
warn!("state db get_last_n_thread_memories_for_cwd failed during {stage}: {err}");
Expand All @@ -364,6 +373,49 @@ pub async fn get_last_n_thread_memories_for_cwd(
}
}

/// Try to acquire or renew a per-cwd memory consolidation lock.
///
/// Returns `None` when no state runtime is available or the underlying call
/// fails (the failure is logged with `stage`); otherwise `Some(acquired)`.
pub async fn try_acquire_memory_consolidation_lock(
    context: Option<&codex_state::StateRuntime>,
    cwd: &Path,
    working_thread_id: ThreadId,
    lease_seconds: i64,
    stage: &str,
) -> Option<bool> {
    let ctx = context?;
    // Normalize so the same workspace always maps to one lock row.
    let normalized_cwd = normalize_cwd_for_state_db(cwd);
    ctx.try_acquire_memory_consolidation_lock(&normalized_cwd, working_thread_id, lease_seconds)
        .await
        .inspect_err(|err| {
            warn!("state db try_acquire_memory_consolidation_lock failed during {stage}: {err}");
        })
        .ok()
}

/// Release a per-cwd memory consolidation lock if held by `working_thread_id`.
///
/// Returns `None` when no state runtime is available or the underlying call
/// fails (the failure is logged with `stage`); otherwise `Some(released)`.
pub async fn release_memory_consolidation_lock(
    context: Option<&codex_state::StateRuntime>,
    cwd: &Path,
    working_thread_id: ThreadId,
    stage: &str,
) -> Option<bool> {
    let ctx = context?;
    // Use the same normalization as acquisition so the delete targets the
    // identical lock row.
    let normalized_cwd = normalize_cwd_for_state_db(cwd);
    ctx.release_memory_consolidation_lock(&normalized_cwd, working_thread_id)
        .await
        .inspect_err(|err| {
            warn!("state db release_memory_consolidation_lock failed during {stage}: {err}");
        })
        .ok()
}

/// Reconcile rollout items into SQLite, falling back to scanning the rollout file.
pub async fn reconcile_rollout(
context: Option<&codex_state::StateRuntime>,
Expand Down Expand Up @@ -400,6 +452,7 @@ pub async fn reconcile_rollout(
}
};
let mut metadata = outcome.metadata;
metadata.cwd = normalize_cwd_for_state_db(&metadata.cwd);
match archived_only {
Some(true) if metadata.archived_at.is_none() => {
metadata.archived_at = Some(metadata.updated_at);
Expand Down Expand Up @@ -447,6 +500,7 @@ pub async fn read_repair_rollout_path(
&& let Ok(Some(mut metadata)) = ctx.get_thread(thread_id).await
{
metadata.rollout_path = rollout_path.to_path_buf();
metadata.cwd = normalize_cwd_for_state_db(&metadata.cwd);
match archived_only {
Some(true) if metadata.archived_at.is_none() => {
metadata.archived_at = Some(metadata.updated_at);
Expand Down Expand Up @@ -509,6 +563,7 @@ pub async fn apply_rollout_items(
},
};
builder.rollout_path = rollout_path.to_path_buf();
builder.cwd = normalize_cwd_for_state_db(&builder.cwd);
if let Err(err) = ctx.apply_rollout_items(&builder, items, None).await {
warn!(
"state db apply_rollout_items failed during {stage} for {}: {err}",
Expand Down
8 changes: 8 additions & 0 deletions codex-rs/state/migrations/0009_memory_consolidation_locks.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Per-working-directory advisory locks that serialize memory consolidation.
-- At most one row exists per cwd (PRIMARY KEY), so an upsert on `cwd` acts
-- as an atomic acquire/renew.
CREATE TABLE memory_consolidation_locks (
    -- Normalized working-directory string identifying the workspace.
    cwd TEXT PRIMARY KEY,
    -- Thread id of the current lease holder.
    working_thread_id TEXT NOT NULL,
    -- Unix timestamp (seconds) of the last acquire/renew; leases older than
    -- the caller-supplied cutoff are treated as expired and may be taken over.
    updated_at INTEGER NOT NULL
);

-- NOTE(review): supports scanning locks by recency (e.g. stale-lease cleanup);
-- single-row lookups already go through the primary key.
CREATE INDEX idx_memory_consolidation_locks_updated_at
ON memory_consolidation_locks(updated_at DESC);
143 changes: 143 additions & 0 deletions codex-rs/state/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,64 @@ LIMIT ?
.collect()
}

/// Try to acquire or renew the per-cwd memory consolidation lock.
///
/// Returns `true` when the lock is acquired/renewed for `working_thread_id`.
/// Returns `false` when another owner holds a non-expired lease.
pub async fn try_acquire_memory_consolidation_lock(
    &self,
    cwd: &Path,
    working_thread_id: ThreadId,
    lease_seconds: i64,
) -> anyhow::Result<bool> {
    let now = Utc::now().timestamp();
    // Leases last-touched at or before this instant are stale and may be
    // taken over; negative lease lengths are clamped to zero.
    let stale_cutoff = now.saturating_sub(lease_seconds.max(0));
    // Conditional upsert: the row is only inserted/updated when we already
    // own it or the previous lease expired, so "a row changed" means the
    // lock is ours.
    let outcome = sqlx::query(
        r#"
INSERT INTO memory_consolidation_locks (
    cwd,
    working_thread_id,
    updated_at
) VALUES (?, ?, ?)
ON CONFLICT(cwd) DO UPDATE SET
    working_thread_id = excluded.working_thread_id,
    updated_at = excluded.updated_at
WHERE memory_consolidation_locks.working_thread_id = excluded.working_thread_id
    OR memory_consolidation_locks.updated_at <= ?
"#,
    )
    .bind(cwd.display().to_string())
    .bind(working_thread_id.to_string())
    .bind(now)
    .bind(stale_cutoff)
    .execute(self.pool.as_ref())
    .await?;

    Ok(outcome.rows_affected() > 0)
}

/// Release the per-cwd memory consolidation lock if held by `working_thread_id`.
///
/// Returns `true` when a lock row was removed.
pub async fn release_memory_consolidation_lock(
    &self,
    cwd: &Path,
    working_thread_id: ThreadId,
) -> anyhow::Result<bool> {
    // The owner check lives in the WHERE clause: a non-owner's release
    // deletes nothing and reports `false`.
    let deleted = sqlx::query(
        r#"
DELETE FROM memory_consolidation_locks
WHERE cwd = ? AND working_thread_id = ?
"#,
    )
    .bind(cwd.display().to_string())
    .bind(working_thread_id.to_string())
    .execute(self.pool.as_ref())
    .await?
    .rows_affected();

    Ok(deleted > 0)
}

/// Persist dynamic tools for a thread if none have been stored yet.
///
/// Dynamic tools are defined at thread start and should not change afterward.
Expand Down Expand Up @@ -1328,6 +1386,91 @@ mod tests {
let _ = tokio::fs::remove_dir_all(codex_home).await;
}

#[tokio::test]
async fn memory_consolidation_lock_enforces_owner_and_release() {
    let codex_home = unique_temp_dir();
    let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
        .await
        .expect("initialize runtime");

    let cwd = codex_home.join("workspace");
    let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
    let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");

    // Fresh lock: first caller wins.
    let first_acquire = runtime
        .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_a, 600)
        .await
        .expect("acquire for owner_a");
    assert!(first_acquire, "owner_a should acquire lock");

    // A different owner cannot take an active lease.
    let contested = runtime
        .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_b, 600)
        .await
        .expect("acquire for owner_b should fail");
    assert!(!contested, "owner_b should not steal active lock");

    // The holder may renew its own lease.
    let renewed = runtime
        .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_a, 600)
        .await
        .expect("owner_a should renew lock");
    assert!(renewed, "owner_a should renew lock");

    // Release by a non-owner is a no-op.
    let foreign_release = runtime
        .release_memory_consolidation_lock(cwd.as_path(), owner_b)
        .await
        .expect("owner_b release should be no-op");
    assert!(!foreign_release, "non-owner release should not remove lock");

    // Release by the owner removes the row...
    let owner_release = runtime
        .release_memory_consolidation_lock(cwd.as_path(), owner_a)
        .await
        .expect("owner_a release");
    assert!(owner_release, "owner_a should release lock");

    // ...after which anyone may acquire again.
    let reacquired = runtime
        .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_b, 600)
        .await
        .expect("owner_b acquire after release");
    assert!(reacquired, "owner_b should acquire released lock");

    let _ = tokio::fs::remove_dir_all(codex_home).await;
}

#[tokio::test]
async fn memory_consolidation_lock_can_be_stolen_when_lease_expired() {
    let codex_home = unique_temp_dir();
    let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string(), None)
        .await
        .expect("initialize runtime");

    let cwd = codex_home.join("workspace");
    let owner_a = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");
    let owner_b = ThreadId::from_string(&Uuid::new_v4().to_string()).expect("thread id");

    let held = runtime
        .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_a, 600)
        .await
        .expect("owner_a acquire");
    assert!(held);

    // A zero-second lease makes owner_a's row immediately stale, so the
    // takeover path of the conditional upsert fires.
    let stolen = runtime
        .try_acquire_memory_consolidation_lock(cwd.as_path(), owner_b, 0)
        .await
        .expect("owner_b steal with expired lease");
    assert!(
        stolen,
        "owner_b should steal lock when lease cutoff marks previous lock stale"
    );

    let _ = tokio::fs::remove_dir_all(codex_home).await;
}

#[tokio::test]
async fn deleting_thread_cascades_thread_memory() {
let codex_home = unique_temp_dir();
Expand Down
Loading