Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,29 @@ fn cmd_config_show(config: &EchoConfig) {
"auto"
}
);
println!();
println!(" Embedding Provider:");
println!(
" {:25} {:>15} {}",
"embedding_provider",
config.embedding_provider.to_string(),
source(
"SHRIMPK_EMBEDDING_PROVIDER",
fc.embedding_provider.is_some()
)
);
println!(
" {:25} {:>15} {}",
"embedding_model",
&config.embedding_model,
source("SHRIMPK_EMBEDDING_MODEL", fc.embedding_model.is_some())
);
println!(
" {:25} {:>15} {}",
"embedding_api_url",
&config.embedding_api_url,
source("SHRIMPK_EMBEDDING_API_URL", fc.embedding_api_url.is_some())
);
}

fn cmd_config_set(key: &str, value: &str) -> anyhow::Result<()> {
Expand All @@ -872,6 +895,11 @@ fn cmd_config_set(key: &str, value: &str) -> anyhow::Result<()> {
"enrichment_model" => fc.enrichment_model = Some(value.to_string()),
"consolidation_provider" => fc.consolidation_provider = Some(value.to_string()),
"max_facts_per_memory" => fc.max_facts_per_memory = Some(value.parse()?),
"embedding_provider" => {
fc.embedding_provider = Some(value.parse().map_err(|e: String| anyhow::anyhow!(e))?)
}
"embedding_model" => fc.embedding_model = Some(value.to_string()),
"embedding_api_url" => fc.embedding_api_url = Some(value.to_string()),
other => anyhow::bail!("Unknown config key: \"{other}\""),
}

Expand Down
233 changes: 233 additions & 0 deletions crates/shrimpk-core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,45 @@ impl std::str::FromStr for RerankerBackend {
}
}

/// Backend for embedding vector generation.
///
/// Controls which embedding model and provider is used for memory storage
/// and echo queries. The default `Fastembed` backend uses a local ONNX model
/// (BGE-small-EN-v1.5, 384-dim) with zero external API calls.
///
/// Parsed from config files and env vars via `FromStr` (case-insensitive;
/// accepts the aliases "local"/"onnx" for `Fastembed` and "api"/"ollama"
/// for `OpenAI`) and rendered back via `Display` as "fastembed"/"openai".
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EmbeddingBackend {
    /// Local fastembed ONNX model (default). Zero network calls.
    /// Models: BGE-small-EN-v1.5 (384-dim), all-MiniLM-L6-v2 (384-dim), etc.
    #[default]
    Fastembed,
    /// OpenAI-compatible embedding API (local or cloud).
    /// Requires `embedding_api_url` and `embedding_model` to be set.
    /// Works with: OpenAI, Ollama `/api/embeddings`, LiteLLM, vLLM, etc.
    OpenAI,
}

impl std::fmt::Display for EmbeddingBackend {
    /// Renders the backend as its canonical lowercase config token
    /// ("fastembed" or "openai"), matching what `FromStr` accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Fastembed => "fastembed",
            Self::OpenAI => "openai",
        };
        f.write_str(label)
    }
}

impl std::str::FromStr for EmbeddingBackend {
    type Err = String;

    /// Parses a case-insensitive provider token.
    ///
    /// Accepts aliases: "local"/"onnx" map to `Fastembed`; "api"/"ollama"
    /// map to `OpenAI`. Anything else is rejected with an error that names
    /// the offending input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        if matches!(normalized.as_str(), "fastembed" | "local" | "onnx") {
            return Ok(Self::Fastembed);
        }
        if matches!(normalized.as_str(), "openai" | "api" | "ollama") {
            return Ok(Self::OpenAI);
        }
        Err(format!(
            "invalid embedding provider '{s}': expected fastembed or openai"
        ))
    }
}

/// Quantization mode for embedding vectors in the echo index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum QuantizationMode {
Expand Down Expand Up @@ -277,6 +316,21 @@ pub struct EchoConfig {
/// Embedding dimension for speech channel. Default: 640 (ECAPA-TDNN 256 + Whisper-tiny 384).
#[serde(default = "default_speech_dim")]
pub speech_embedding_dim: usize,

// --- Embedding provider (KS75) ---
/// Embedding backend: `Fastembed` (local ONNX, default) or `OpenAI` (API).
#[serde(default)]
pub embedding_provider: EmbeddingBackend,
/// Model name for the embedding provider.
/// Fastembed: "BGE-small-EN-v1.5" (default), "all-MiniLM-L6-v2", etc.
/// OpenAI: "text-embedding-3-small", "nomic-embed-text", Ollama model name, etc.
#[serde(default = "default_embedding_model")]
pub embedding_model: String,
/// API URL for OpenAI-compatible embedding providers.
/// Only used when `embedding_provider = OpenAI`.
/// Default: "http://127.0.0.1:11434" (Ollama).
#[serde(default = "default_embedding_api_url")]
pub embedding_api_url: String,
}

fn default_true() -> bool {
Expand Down Expand Up @@ -317,6 +371,13 @@ fn default_speech_dim() -> usize {
640
}

/// Default model name for the embedding provider (fastembed's BGE-small).
fn default_embedding_model() -> String {
    String::from("BGE-small-EN-v1.5")
}

/// Default OpenAI-compatible embedding endpoint (local Ollama port).
fn default_embedding_api_url() -> String {
    String::from("http://127.0.0.1:11434")
}

/// Default upstream target for the proxy (local Ollama port).
fn default_proxy_target() -> String {
    String::from("http://127.0.0.1:11434")
}
Expand Down Expand Up @@ -411,6 +472,9 @@ impl Default for EchoConfig {
enabled_modalities: default_modalities(),
vision_embedding_dim: default_vision_dim(),
speech_embedding_dim: default_speech_dim(),
embedding_provider: EmbeddingBackend::default(),
embedding_model: default_embedding_model(),
embedding_api_url: default_embedding_api_url(),
}
}
}
Expand Down Expand Up @@ -476,6 +540,37 @@ impl EchoConfig {
}
}

/// Infer the embedding dimension from the configured model name.
///
/// Returns the known dimension for well-known models, or falls back to
/// `self.embedding_dim` (the explicitly configured value) if the model
/// is not recognized. This lets users set `embedding_model` without also
/// needing to manually set `embedding_dim`.
pub fn infer_embedding_dim(&self) -> usize {
match self.embedding_model.to_lowercase().as_str() {
// fastembed ONNX models
s if s.contains("bge-small") => 384,
s if s.contains("bge-base") => 768,
s if s.contains("bge-large") => 1024,
s if s.contains("bge-m3") => 1024,
s if s.contains("gte-large") => 1024,
s if s.contains("gte-base") => 768,
s if s.contains("minilm-l6") => 384,
s if s.contains("minilm-l12") => 384,
// OpenAI
s if s.contains("text-embedding-3-small") => 1536,
s if s.contains("text-embedding-3-large") => 3072,
s if s.contains("text-embedding-ada") => 1536,
// Ollama common models
s if s.contains("nomic-embed-text") => 768,
s if s.contains("mxbai-embed-large") => 1024,
s if s.contains("all-minilm") => 384,
s if s.contains("snowflake-arctic-embed") => 1024,
// Fallback to explicit config
_ => self.embedding_dim,
}
Comment thread
greptile-apps[bot] marked this conversation as resolved.
}

/// Estimated index size in bytes for the current config.
pub fn estimated_index_bytes(&self) -> u64 {
let bytes_per_entry = self.quantization.bytes_per_vector(self.embedding_dim) + 100;
Expand Down Expand Up @@ -537,6 +632,9 @@ pub struct FileConfig {
pub enabled_modalities: Option<Vec<crate::Modality>>,
pub vision_embedding_dim: Option<usize>,
pub speech_embedding_dim: Option<usize>,
pub embedding_provider: Option<EmbeddingBackend>,
pub embedding_model: Option<String>,
pub embedding_api_url: Option<String>,
}

/// Default data directory: `~/.shrimpk-kernel/`
Expand Down Expand Up @@ -642,6 +740,7 @@ pub fn resolve_config() -> crate::Result<EchoConfig> {
let mut config = EchoConfig::auto_detect();

// Layer 2: file overrides
let mut dim_set_by_file = false;
if let Some(fc) = load_config_file()? {
if let Some(v) = fc.max_memories {
config.max_memories = v;
Expand All @@ -663,6 +762,7 @@ pub fn resolve_config() -> crate::Result<EchoConfig> {
}
if let Some(v) = fc.embedding_dim {
config.embedding_dim = v;
dim_set_by_file = true;
}
if let Some(v) = fc.use_lsh {
config.use_lsh = v;
Expand Down Expand Up @@ -781,6 +881,15 @@ pub fn resolve_config() -> crate::Result<EchoConfig> {
if let Some(v) = fc.speech_embedding_dim {
config.speech_embedding_dim = v;
}
if let Some(v) = fc.embedding_provider {
config.embedding_provider = v;
}
if let Some(v) = fc.embedding_model {
config.embedding_model = v;
}
if let Some(v) = fc.embedding_api_url {
config.embedding_api_url = v;
}
}

// Layer 3: env var overrides (highest priority)
Expand Down Expand Up @@ -844,6 +953,26 @@ pub fn resolve_config() -> crate::Result<EchoConfig> {
config.hebbian_prune_threshold = v;
}

if let Ok(v) = std::env::var("SHRIMPK_EMBEDDING_PROVIDER")
&& let Ok(provider) = v.parse::<EmbeddingBackend>()
{
config.embedding_provider = provider;
}
if let Ok(v) = std::env::var("SHRIMPK_EMBEDDING_MODEL") {
config.embedding_model = v;
}
if let Ok(v) = std::env::var("SHRIMPK_EMBEDDING_API_URL") {
config.embedding_api_url = v;
}

// Auto-infer embedding_dim from model name unless explicitly overridden
// by either env var (Layer 3) or config file (Layer 2).
if let Some(v) = env_usize("SHRIMPK_EMBEDDING_DIM")? {
config.embedding_dim = v;
} else if !dim_set_by_file {
config.embedding_dim = config.infer_embedding_dim();
}

// Backward compatibility: if reranker_enabled=true but backend=None, default to Llm
if config.reranker_enabled && config.reranker_backend == RerankerBackend::None {
config.reranker_backend = RerankerBackend::Llm;
Expand Down Expand Up @@ -1265,4 +1394,108 @@ mod tests {
"Explicit backend should override legacy reranker_enabled"
);
}

// --- KS75: EmbeddingBackend ---

#[test]
fn embedding_provider_default_is_fastembed() {
let config = EchoConfig::default();
assert_eq!(config.embedding_provider, EmbeddingBackend::Fastembed);
assert_eq!(config.embedding_model, "BGE-small-EN-v1.5");
}

#[test]
fn embedding_provider_parse_roundtrip() {
for (input, expected) in [
("fastembed", EmbeddingBackend::Fastembed),
("local", EmbeddingBackend::Fastembed),
("onnx", EmbeddingBackend::Fastembed),
("openai", EmbeddingBackend::OpenAI),
("api", EmbeddingBackend::OpenAI),
("ollama", EmbeddingBackend::OpenAI),
] {
let parsed: EmbeddingBackend = input.parse().unwrap();
assert_eq!(parsed, expected, "parsing '{input}'");
}
}

#[test]
fn embedding_provider_parse_invalid() {
assert!("unknown".parse::<EmbeddingBackend>().is_err());
}

#[test]
fn embedding_provider_display() {
assert_eq!(EmbeddingBackend::Fastembed.to_string(), "fastembed");
assert_eq!(EmbeddingBackend::OpenAI.to_string(), "openai");
}

#[test]
fn infer_embedding_dim_known_models() {
let cases: &[(&str, usize)] = &[
("BGE-small-EN-v1.5", 384),
("BGE-base-EN-v1.5", 768),
("BGE-large-EN-v1.5", 1024),
("bge-m3", 1024),
("gte-large-en-v1.5", 1024),
("gte-base-en-v1.5", 768),
("text-embedding-3-small", 1536),
("text-embedding-3-large", 3072),
("nomic-embed-text", 768),
];
for &(model, expected_dim) in cases {
let config = EchoConfig {
embedding_model: model.into(),
..Default::default()
};
assert_eq!(
config.infer_embedding_dim(),
expected_dim,
"model '{model}'"
);
}
}

#[test]
fn infer_embedding_dim_unknown_falls_back() {
let config = EchoConfig {
embedding_model: "my-custom-model".into(),
embedding_dim: 512,
..Default::default()
};
assert_eq!(
config.infer_embedding_dim(),
512,
"Unknown model should fall back to embedding_dim"
);
}

#[test]
fn embedding_provider_serde_roundtrip() {
let config = EchoConfig {
embedding_provider: EmbeddingBackend::OpenAI,
embedding_model: "text-embedding-3-small".into(),
embedding_api_url: "https://api.openai.com".into(),
..Default::default()
};
let json = serde_json::to_string(&config).unwrap();
let parsed: EchoConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.embedding_provider, EmbeddingBackend::OpenAI);
assert_eq!(parsed.embedding_model, "text-embedding-3-small");
assert_eq!(parsed.embedding_api_url, "https://api.openai.com");
}

#[test]
fn file_config_embedding_fields_toml_roundtrip() {
let fc = FileConfig {
embedding_provider: Some(EmbeddingBackend::OpenAI),
embedding_model: Some("nomic-embed-text".into()),
embedding_api_url: Some("http://localhost:11434".into()),
..Default::default()
};
let toml_str = toml::to_string_pretty(&fc).unwrap();
let parsed: FileConfig = toml::from_str(&toml_str).unwrap();
assert_eq!(parsed.embedding_provider, Some(EmbeddingBackend::OpenAI));
assert_eq!(parsed.embedding_model, Some("nomic-embed-text".into()));
}
}
7 changes: 4 additions & 3 deletions crates/shrimpk-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ pub mod traits;

// Re-export commonly used types at crate root
pub use config::{
EchoConfig, FileConfig, QuantizationMode, RerankerBackend, config_dir, config_path, disk_usage,
load_config_file, resolve_config, save_config_file,
EchoConfig, EmbeddingBackend, FileConfig, QuantizationMode, RerankerBackend, config_dir,
config_path, disk_usage, load_config_file, resolve_config, save_config_file,
};
pub use entity::{EntityFrame, EntityId, EntityKind};
pub use error::{Result, ShrimPKError};
Expand All @@ -26,5 +26,6 @@ pub use memory::{
};
pub use pii::{PiiMatch, PiiType};
pub use traits::{
ConsolidationOutput, Consolidator, ExtractedFact, FactType, LabelSet, ModelBackend, Provider,
ConsolidationOutput, Consolidator, EmbeddingProvider, ExtractedFact, FactType, LabelSet,
ModelBackend, Provider,
};
Loading
Loading