Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions crates/zeph-core/src/agent/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,21 @@ impl<C: Channel> Agent<C> {
return;
}
let provider = self.embedding_provider.clone();
let embed_fn = |text: &str| -> zeph_mcp::registry::EmbedFuture {
let embed_timeout = std::time::Duration::from_secs(self.runtime.timeouts.embedding_seconds);
let embed_fn = move |text: &str| -> zeph_mcp::registry::EmbedFuture {
let owned = text.to_owned();
let p = provider.clone();
Box::pin(async move { p.embed(&owned).await })
Box::pin(async move {
if let Ok(result) = tokio::time::timeout(embed_timeout, p.embed(&owned)).await {
result
} else {
tracing::warn!(
timeout_secs = embed_timeout.as_secs(),
"MCP registry: embedding timed out"
);
Err(zeph_llm::LlmError::Timeout)
}
})
};
if let Err(e) = registry
.sync(&self.mcp.tools, &self.skill_state.embedding_model, embed_fn)
Expand Down Expand Up @@ -586,7 +597,22 @@ impl<C: Channel> Agent<C> {
.clone()
.unwrap_or_else(|| self.embedding_provider.clone());

let embed_fn = provider.embed_fn();
let inner_embed = provider.embed_fn();
let embed_timeout = std::time::Duration::from_secs(self.runtime.timeouts.embedding_seconds);
let embed_fn = move |text: &str| -> zeph_llm::provider::EmbedFuture {
let fut = inner_embed(text);
Box::pin(async move {
if let Ok(result) = tokio::time::timeout(embed_timeout, fut).await {
result
} else {
tracing::warn!(
timeout_secs = embed_timeout.as_secs(),
"semantic index: embedding probe timed out"
);
Err(zeph_llm::LlmError::Timeout)
}
})
};

match zeph_mcp::SemanticToolIndex::build(&self.mcp.tools, &embed_fn).await {
Ok(idx) => {
Expand Down
15 changes: 13 additions & 2 deletions crates/zeph-core/src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1794,10 +1794,21 @@ impl<C: Channel> Agent<C> {
/// Rebuild or sync the in-memory skill matcher and BM25 index after a registry update.
async fn rebuild_skill_matcher(&mut self, all_meta: &[&zeph_skills::loader::SkillMeta]) {
let provider = self.embedding_provider.clone();
let embed_fn = |text: &str| -> zeph_skills::matcher::EmbedFuture {
let embed_timeout = std::time::Duration::from_secs(self.runtime.timeouts.embedding_seconds);
let embed_fn = move |text: &str| -> zeph_skills::matcher::EmbedFuture {
let owned = text.to_owned();
let p = provider.clone();
Box::pin(async move { p.embed(&owned).await })
Box::pin(async move {
if let Ok(result) = tokio::time::timeout(embed_timeout, p.embed(&owned)).await {
result
} else {
tracing::warn!(
timeout_secs = embed_timeout.as_secs(),
"skill matcher: embedding timed out"
);
Err(zeph_llm::LlmError::Timeout)
}
})
};

let needs_inmemory_rebuild = !self
Expand Down
17 changes: 16 additions & 1 deletion crates/zeph-core/src/bootstrap/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,22 @@ pub async fn create_mcp_registry(
return None;
};
let mut reg = zeph_mcp::McpToolRegistry::with_ops(ops.clone());
let embed_fn = provider.embed_fn();
let inner_embed = provider.embed_fn();
let embed_timeout = std::time::Duration::from_secs(config.timeouts.embedding_seconds);
let embed_fn = move |text: &str| -> zeph_llm::provider::EmbedFuture {
let fut = inner_embed(text);
Box::pin(async move {
if let Ok(result) = tokio::time::timeout(embed_timeout, fut).await {
result
} else {
tracing::warn!(
timeout_secs = embed_timeout.as_secs(),
"MCP tool embedding timed out, skipping tool"
);
Err(zeph_llm::LlmError::Timeout)
}
})
};
if let Err(e) = reg.sync(mcp_tools, embedding_model, &embed_fn).await {
tracing::warn!("MCP tool embedding sync failed: {e:#}");
}
Expand Down
17 changes: 16 additions & 1 deletion crates/zeph-core/src/bootstrap/skills.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,22 @@ pub async fn create_skill_matcher(
embedding_model: &str,
qdrant_ops: Option<&QdrantOps>,
) -> Option<SkillMatcherBackend> {
let embed_fn = provider.embed_fn();
let inner_embed = provider.embed_fn();
let embed_timeout = std::time::Duration::from_secs(config.timeouts.embedding_seconds);
let embed_fn = move |text: &str| -> zeph_llm::provider::EmbedFuture {
let fut = inner_embed(text);
Box::pin(async move {
if let Ok(result) = tokio::time::timeout(embed_timeout, fut).await {
result
} else {
tracing::warn!(
timeout_secs = embed_timeout.as_secs(),
"skill matcher: embedding probe timed out"
);
Err(zeph_llm::LlmError::Timeout)
}
})
};

if config.memory.semantic.enabled
&& memory.is_vector_store_connected().await
Expand Down
23 changes: 23 additions & 0 deletions crates/zeph-llm/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub struct MockProvider {
pub embed_invalid_input: bool,
/// Tracks how many times `embed()` was called. Useful for verifying embed reuse.
pub embed_call_count: Arc<std::sync::atomic::AtomicU64>,
/// Milliseconds to sleep inside `embed()` before returning. Used to simulate slow providers.
pub embed_delay_ms: u64,
}

impl Default for MockProvider {
Expand All @@ -60,6 +62,7 @@ impl Default for MockProvider {
name_override: None,
embed_invalid_input: false,
embed_call_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
embed_delay_ms: 0,
}
}
}
Expand Down Expand Up @@ -119,6 +122,23 @@ impl MockProvider {
self
}

/// Enable embedding support with a fixed return vector.
#[must_use]
pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
self.embedding = embedding;
self.supports_embeddings = true;
self
}

/// Make `embed()` sleep for `ms` milliseconds before returning.
/// Useful for testing timeout behaviour.
#[must_use]
pub fn with_embed_delay(mut self, ms: u64) -> Self {
self.embed_delay_ms = ms;
self.supports_embeddings = true;
self
}

/// Enable call recording. Returns the shared buffer. Each `chat()` call
/// appends a clone of the messages slice so tests can inspect them.
#[must_use]
Expand Down Expand Up @@ -202,6 +222,9 @@ impl LlmProvider for MockProvider {
async fn embed(&self, _text: &str) -> Result<Vec<f32>, crate::LlmError> {
self.embed_call_count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if self.embed_delay_ms > 0 {
tokio::time::sleep(std::time::Duration::from_millis(self.embed_delay_ms)).await;
}
if let Ok(mut errors) = self.errors.lock()
&& !errors.is_empty()
{
Expand Down
9 changes: 8 additions & 1 deletion src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,14 @@ pub(crate) async fn run_daemon(

// Pre-resolve RL embed dim before embedding_provider is moved into the agent builder.
let rl_embed_dim_resolved = if config.skills.rl_routing_enabled {
Some(crate::runner::resolve_rl_embed_dim(&config.skills, &embedding_provider).await)
Some(
crate::runner::resolve_rl_embed_dim(
&config.skills,
&embedding_provider,
config.timeouts.embedding_seconds,
)
.await,
)
} else {
None
};
Expand Down
58 changes: 53 additions & 5 deletions src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,14 @@ pub(crate) async fn run(cli: Cli) -> anyhow::Result<()> {

// Pre-resolve RL embed dim before embedding_provider is moved into the agent builder.
let rl_embed_dim_resolved = if config.skills.rl_routing_enabled {
Some(resolve_rl_embed_dim(&config.skills, &embedding_provider).await)
Some(
resolve_rl_embed_dim(
&config.skills,
&embedding_provider,
config.timeouts.embedding_seconds,
)
.await,
)
} else {
None
};
Expand Down Expand Up @@ -2183,21 +2190,36 @@ pub(crate) async fn load_rl_head(
pub(crate) async fn resolve_rl_embed_dim(
skills_config: &zeph_core::config::SkillsConfig,
embedding_provider: &LlmAnyProvider,
embedding_timeout_secs: u64,
) -> usize {
const FALLBACK: usize = 1536;
if let Some(dim) = skills_config.rl_embed_dim {
return dim;
}
match embedding_provider.embed(" ").await {
Ok(v) if !v.is_empty() => v.len(),
Ok(_) | Err(_) => {
const FALLBACK: usize = 1536;
let probe = tokio::time::timeout(
std::time::Duration::from_secs(embedding_timeout_secs),
embedding_provider.embed(" "),
)
.await;
match probe {
Ok(Ok(v)) if !v.is_empty() => v.len(),
Ok(Ok(_) | Err(_)) => {
tracing::warn!(
fallback = FALLBACK,
"rl_head: could not probe embedding dimension from provider; \
set `skills.rl_embed_dim` in config to avoid this fallback"
);
FALLBACK
}
Err(_) => {
tracing::warn!(
timeout_secs = embedding_timeout_secs,
fallback = FALLBACK,
"rl_head: embedding probe timed out; \
set `skills.rl_embed_dim` in config to avoid this fallback"
);
FALLBACK
}
}
}

Expand Down Expand Up @@ -2598,4 +2620,30 @@ mod tests {
config.acp.transport = AcpTransport::Http;
assert!(configured_acp_autostart_transport(&config, &cli).is_none());
}

// --- resolve_rl_embed_dim ---

/// A slow embed (1100 ms) cut off by a 1-second timeout must fall back to 1536.
#[tokio::test]
async fn resolve_rl_embed_dim_timeout_uses_fallback() {
use zeph_llm::mock::MockProvider;
let config = zeph_core::Config::default();
// 1100 ms delay > 1 s timeout → guaranteed to trigger, 100 ms safety margin
let provider =
zeph_llm::any::AnyProvider::Mock(MockProvider::default().with_embed_delay(1100));
let dim = resolve_rl_embed_dim(&config.skills, &provider, 1).await;
assert_eq!(dim, 1536);
}

/// A fast embed returning a 768-dim vector must be returned unchanged.
#[tokio::test]
async fn resolve_rl_embed_dim_fast_provider_returns_dim() {
use zeph_llm::mock::MockProvider;
let config = zeph_core::Config::default();
let provider = zeph_llm::any::AnyProvider::Mock(
MockProvider::default().with_embedding(vec![0.0f32; 768]),
);
let dim = resolve_rl_embed_dim(&config.skills, &provider, 30).await;
assert_eq!(dim, 768);
}
}
Loading