Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
473 changes: 303 additions & 170 deletions codex-rs/core/src/client.rs

Large diffs are not rendered by default.

138 changes: 88 additions & 50 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::compact::should_use_remote_compact_task;
use crate::compact_remote::run_inline_remote_auto_compact_task;
use crate::connectors;
use crate::exec_policy::ExecPolicyManager;
use crate::features::FEATURES;
use crate::features::Feature;
use crate::features::Features;
use crate::features::maybe_push_unstable_features_warning;
Expand All @@ -32,7 +33,6 @@ use crate::stream_events_utils::handle_non_tool_response_item;
use crate::stream_events_utils::handle_output_item_done;
use crate::stream_events_utils::last_assistant_message_from_item;
use crate::terminal;
use crate::transport_manager::TransportManager;
use crate::truncate::TruncationPolicy;
use crate::turn_metadata::build_turn_metadata_header;
use crate::user_notification::UserNotifier;
Expand Down Expand Up @@ -488,7 +488,6 @@ pub(crate) struct Session {
#[derive(Debug)]
pub(crate) struct TurnContext {
pub(crate) sub_id: String,
pub(crate) client: ModelClient,
pub(crate) config: Arc<Config>,
pub(crate) auth_manager: Option<Arc<AuthManager>>,
pub(crate) model_info: ModelInfo,
Expand All @@ -497,7 +496,6 @@ pub(crate) struct TurnContext {
pub(crate) reasoning_effort: Option<ReasoningEffortConfig>,
pub(crate) reasoning_summary: ReasoningSummaryConfig,
pub(crate) session_source: SessionSource,
pub(crate) transport_manager: TransportManager,
/// The session's current working directory. All relative paths provided by
/// the model as well as sandbox policies are resolved against this path
/// instead of `std::env::current_dir()`.
Expand Down Expand Up @@ -681,6 +679,33 @@ pub(crate) struct SessionSettingsUpdate {
}

impl Session {
/// Builds the `x-codex-beta-features` header value for this session.
///
/// `ModelClient` is session-scoped and intentionally does not depend on the full `Config`, so
/// we precompute the comma-separated list of enabled experimental feature keys at session
/// creation time and thread it into the client.
fn build_model_client_beta_features_header(config: &Config) -> Option<String> {
let beta_features_header = FEATURES
.iter()
.filter_map(|spec| {
if spec.stage.experimental_menu_description().is_some()
&& config.features.enabled(spec.id)
{
Some(spec.key)
} else {
None
}
})
.collect::<Vec<_>>()
.join(",");

if beta_features_header.is_empty() {
None
} else {
Some(beta_features_header)
}
}

/// Don't expand the number of mutated arguments on config. We are in the process of getting rid of it.
pub(crate) fn build_per_turn_config(session_configuration: &SessionConfiguration) -> Config {
// todo(aibrahim): store this state somewhere else so we don't need to mut config
Expand Down Expand Up @@ -735,9 +760,7 @@ impl Session {
session_configuration: &SessionConfiguration,
per_turn_config: Config,
model_info: ModelInfo,
conversation_id: ThreadId,
sub_id: String,
transport_manager: TransportManager,
) -> TurnContext {
let reasoning_effort = session_configuration.collaboration_mode.reasoning_effort();
let reasoning_summary = session_configuration.model_reasoning_summary;
Expand All @@ -746,23 +769,10 @@ impl Session {
model_info.slug.as_str(),
);
let session_source = session_configuration.session_source.clone();
let auth_manager_for_context = auth_manager.clone();
let provider_for_context = provider.clone();
let transport_manager_for_context = transport_manager.clone();
let otel_manager_for_context = otel_manager.clone();
let auth_manager_for_context = auth_manager;
let provider_for_context = provider;
let otel_manager_for_context = otel_manager;
let per_turn_config = Arc::new(per_turn_config);
let client = ModelClient::new(
per_turn_config.clone(),
auth_manager,
model_info.clone(),
otel_manager,
provider,
reasoning_effort,
reasoning_summary,
conversation_id,
session_source.clone(),
transport_manager,
);

let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_info: &model_info,
Expand All @@ -773,7 +783,6 @@ impl Session {
let cwd = session_configuration.cwd.clone();
TurnContext {
sub_id,
client,
config: per_turn_config.clone(),
auth_manager: auth_manager_for_context,
model_info: model_info.clone(),
Expand All @@ -782,7 +791,6 @@ impl Session {
reasoning_effort,
reasoning_summary,
session_source,
transport_manager: transport_manager_for_context,
cwd,
developer_instructions: session_configuration.developer_instructions.clone(),
compact_prompt: session_configuration.compact_prompt.clone(),
Expand Down Expand Up @@ -1020,7 +1028,17 @@ impl Session {
file_watcher,
agent_control,
state_db: state_db_ctx.clone(),
transport_manager: TransportManager::new(),
model_client: ModelClient::new(
Some(Arc::clone(&auth_manager)),
conversation_id,
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::ResponsesWebsockets),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Self::build_model_client_beta_features_header(config.as_ref()),
),
};

let sess = Arc::new(Session {
Expand Down Expand Up @@ -1351,9 +1369,7 @@ impl Session {
&session_configuration,
per_turn_config,
model_info,
self.conversation_id,
sub_id,
self.services.transport_manager.clone(),
);

if let Some(final_schema) = final_output_json_schema {
Expand Down Expand Up @@ -3353,25 +3369,11 @@ async fn spawn_review_thread(
let reasoning_effort = per_turn_config.model_reasoning_effort;
let reasoning_summary = per_turn_config.model_reasoning_summary;
let session_source = parent_turn_context.session_source.clone();
let transport_manager = parent_turn_context.transport_manager.clone();

let per_turn_config = Arc::new(per_turn_config);
let client = ModelClient::new(
per_turn_config.clone(),
auth_manager,
model_info.clone(),
otel_manager,
provider,
reasoning_effort,
reasoning_summary,
sess.conversation_id,
session_source.clone(),
transport_manager.clone(),
);

let review_turn_context = TurnContext {
sub_id: sub_id.to_string(),
client,
config: per_turn_config,
auth_manager: auth_manager_for_context,
model_info: model_info.clone(),
Expand All @@ -3380,7 +3382,6 @@ async fn spawn_review_thread(
reasoning_effort,
reasoning_summary,
session_source,
transport_manager,
tools_config,
features: parent_turn_context.features.clone(),
ghost_snapshot: parent_turn_context.ghost_snapshot.clone(),
Expand Down Expand Up @@ -3605,7 +3606,9 @@ pub(crate) async fn run_turn(
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));

let turn_metadata_header = turn_context.resolve_turn_metadata_header().await;
let mut client_session = turn_context.client.new_session(turn_metadata_header);
// `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse
// one instance across retries within this turn.
let mut client_session = sess.services.model_client.new_session();

loop {
// Note that pending_input would be something like a message the user
Expand Down Expand Up @@ -3658,6 +3661,7 @@ pub(crate) async fn run_turn(
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
&mut client_session,
turn_metadata_header.as_deref(),
sampling_request_input,
tool_selection,
cancellation_token.child_token(),
Expand Down Expand Up @@ -3844,6 +3848,7 @@ struct SamplingRequestToolSelection<'a> {
skill_name_counts_lower: &'a HashMap<String, usize>,
}

#[allow(clippy::too_many_arguments)]
#[instrument(level = "trace",
skip_all,
fields(
Expand All @@ -3857,6 +3862,7 @@ async fn run_sampling_request(
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
client_session: &mut ModelClientSession,
turn_metadata_header: Option<&str>,
input: Vec<ResponseItem>,
tool_selection: SamplingRequestToolSelection<'_>,
cancellation_token: CancellationToken,
Expand Down Expand Up @@ -3914,6 +3920,7 @@ async fn run_sampling_request(
Arc::clone(&sess),
Arc::clone(&turn_context),
client_session,
turn_metadata_header,
Arc::clone(&turn_diff_tracker),
&prompt,
cancellation_token.child_token(),
Expand Down Expand Up @@ -3943,7 +3950,9 @@ async fn run_sampling_request(

// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.provider.stream_max_retries();
if retries >= max_retries && client_session.try_switch_fallback_transport() {
if retries >= max_retries
&& client_session.try_switch_fallback_transport(&turn_context.otel_manager)
{
sess.send_event(
&turn_context,
EventMsg::Warning(WarningEvent {
Expand Down Expand Up @@ -4396,6 +4405,7 @@ async fn try_run_sampling_request(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
client_session: &mut ModelClientSession,
turn_metadata_header: Option<&str>,
turn_diff_tracker: SharedTurnDiffTracker,
prompt: &Prompt,
cancellation_token: CancellationToken,
Expand Down Expand Up @@ -4426,8 +4436,20 @@ async fn try_run_sampling_request(
);

sess.persist_rollout_items(&[rollout_item]).await;
let web_search_eligible = !matches!(
turn_context.config.web_search_mode,
Some(WebSearchMode::Disabled)
);
let mut stream = client_session
.stream(prompt)
.stream(
prompt,
&turn_context.model_info,
&turn_context.otel_manager,
turn_context.reasoning_effort,
turn_context.reasoning_summary,
web_search_eligible,
turn_metadata_header,
)
.instrument(trace_span!("stream_request"))
.or_cancel(&cancellation_token)
.await??;
Expand Down Expand Up @@ -5594,7 +5616,17 @@ mod tests {
file_watcher,
agent_control,
state_db: None,
transport_manager: TransportManager::new(),
model_client: ModelClient::new(
Some(auth_manager.clone()),
conversation_id,
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::ResponsesWebsockets),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
),
};

let turn_context = Session::make_turn_context(
Expand All @@ -5604,9 +5636,7 @@ mod tests {
&session_configuration,
per_turn_config,
model_info,
conversation_id,
"turn_id".to_string(),
services.transport_manager.clone(),
);

let session = Session {
Expand Down Expand Up @@ -5716,7 +5746,17 @@ mod tests {
file_watcher,
agent_control,
state_db: None,
transport_manager: TransportManager::new(),
model_client: ModelClient::new(
Some(Arc::clone(&auth_manager)),
conversation_id,
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::ResponsesWebsockets),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
),
};

let turn_context = Arc::new(Session::make_turn_context(
Expand All @@ -5726,9 +5766,7 @@ mod tests {
&session_configuration,
per_turn_config,
model_info,
conversation_id,
"turn_id".to_string(),
services.transport_manager.clone(),
));

let session = Arc::new(Session {
Expand Down
35 changes: 31 additions & 4 deletions codex-rs/core/src/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;

use crate::ModelProviderInfo;
use crate::Prompt;
use crate::client::ModelClientSession;
use crate::client_common::ResponseEvent;
use crate::codex::Session;
use crate::codex::TurnContext;
Expand All @@ -19,6 +20,7 @@ use crate::truncate::TruncationPolicy;
use crate::truncate::approx_token_count;
use crate::truncate::truncate_text;
use crate::util::backoff;
use codex_protocol::config_types::WebSearchMode;
use codex_protocol::items::ContextCompactionItem;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ContentItem;
Expand Down Expand Up @@ -87,6 +89,10 @@ async fn run_compact_task_inner(

let max_retries = turn_context.provider.stream_max_retries();
let mut retries = 0;
let turn_metadata_header = turn_context.resolve_turn_metadata_header().await;
let mut client_session = sess.services.model_client.new_session();
// Reuse one client session so turn-scoped state (sticky routing, websocket append tracking)
// survives retries within this compact turn.

// TODO: If we need to guarantee the persisted mode always matches the prompt used for this
// turn, capture it in TurnContext at creation time. Using SessionConfiguration here avoids
Expand Down Expand Up @@ -119,7 +125,14 @@ async fn run_compact_task_inner(
personality: turn_context.personality,
..Default::default()
};
let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
let attempt_result = drain_to_completed(
&sess,
turn_context.as_ref(),
&mut client_session,
turn_metadata_header.as_deref(),
&prompt,
)
.await;

match attempt_result {
Ok(()) => {
Expand Down Expand Up @@ -335,11 +348,25 @@ fn build_compacted_history_with_limit(
async fn drain_to_completed(
sess: &Session,
turn_context: &TurnContext,
client_session: &mut ModelClientSession,
turn_metadata_header: Option<&str>,
prompt: &Prompt,
) -> CodexResult<()> {
let turn_metadata_header = turn_context.resolve_turn_metadata_header().await;
let mut client_session = turn_context.client.new_session(turn_metadata_header);
let mut stream = client_session.stream(prompt).await?;
let web_search_eligible = !matches!(
turn_context.config.web_search_mode,
Some(WebSearchMode::Disabled)
);
let mut stream = client_session
.stream(
prompt,
&turn_context.model_info,
&turn_context.otel_manager,
turn_context.reasoning_effort,
turn_context.reasoning_summary,
web_search_eligible,
turn_metadata_header,
)
.await?;
loop {
let maybe_event = stream.next().await;
let Some(event) = maybe_event else {
Expand Down
Loading
Loading