diff --git a/crates/aft/src/semantic_index.rs b/crates/aft/src/semantic_index.rs
index 30b731a8..44844497 100644
--- a/crates/aft/src/semantic_index.rs
+++ b/crates/aft/src/semantic_index.rs
@@ -23,7 +23,9 @@ use url::Url;
 
 const DEFAULT_DIMENSION: usize = 384;
 const MAX_ENTRIES: usize = 1_000_000;
-const MAX_DIMENSION: usize = 1024;
+// Covers high-dimensional backends such as OpenAI text-embedding-3-large (3072)
+// and common local models (4096) while keeping a bounded supported shape.
+const MAX_DIMENSION: usize = 4096;
 const F32_BYTES: usize = std::mem::size_of::<f32>();
 const HEADER_BYTES_V1: usize = 9;
 const HEADER_BYTES_V2: usize = 13;
@@ -176,6 +178,8 @@ fn validate_embedding_batch(
         return Ok(());
     };
     let expected_dimension = first_vector.len();
+    validate_embedding_dimension(expected_dimension)
+        .map_err(|error| format!("{context} returned {error}"))?;
     for (index, vector) in vectors.iter().enumerate() {
         if vector.len() != expected_dimension {
             return Err(format!(
@@ -188,6 +192,16 @@ fn validate_embedding_batch(
     Ok(())
 }
 
+fn validate_embedding_dimension(dimension: usize) -> Result<(), String> {
+    if dimension == 0 || dimension > MAX_DIMENSION {
+        return Err(format!(
+            "invalid embedding dimension: {dimension}; supported range is 1..={MAX_DIMENSION}"
+        ));
+    }
+
+    Ok(())
+}
+
 /// Normalize a base URL: validate scheme and strip trailing slash.
 /// Does NOT perform SSRF/private-IP validation — call
 /// `validate_base_url_no_ssrf` separately when processing user-supplied config.
@@ -1837,9 +1851,7 @@ impl SemanticIndex {
 
         let dimension = read_u32(data, &mut pos)? as usize;
         let entry_count = read_u32(data, &mut pos)? as usize;
-        if dimension == 0 || dimension > MAX_DIMENSION {
-            return Err(format!("invalid embedding dimension: {}", dimension));
-        }
+        validate_embedding_dimension(dimension)?;
         if entry_count > MAX_ENTRIES {
             return Err(format!("too many semantic index entries: {}", entry_count));
         }
diff --git a/crates/aft/tests/semantic_validation_test.rs b/crates/aft/tests/semantic_validation_test.rs
index 5ca97765..a320f20b 100644
--- a/crates/aft/tests/semantic_validation_test.rs
+++ b/crates/aft/tests/semantic_validation_test.rs
@@ -1,5 +1,5 @@
 use std::fs;
-use std::path::PathBuf;
+use std::path::{Path, PathBuf};
 
 use aft::semantic_index::SemanticIndex;
 
@@ -14,6 +14,22 @@ fn write_source_fixture(project_root: &std::path::Path) -> PathBuf {
     source_file
 }
 
+fn build_empty_v6_bytes(dimension: u32) -> Vec<u8> {
+    let mut bytes = Vec::new();
+    bytes.push(6u8);
+    bytes.extend_from_slice(&dimension.to_le_bytes());
+    bytes.extend_from_slice(&0u32.to_le_bytes()); // entry_count
+    bytes.extend_from_slice(&0u32.to_le_bytes()); // fingerprint_len
+    bytes.extend_from_slice(&0u32.to_le_bytes()); // mtime_count
+    bytes
+}
+
+fn unit_vector(dimension: usize, hot_index: usize) -> Vec<f32> {
+    let mut vector = vec![0.0; dimension];
+    vector[hot_index] = 1.0;
+    vector
+}
+
 #[test]
 fn build_returns_backend_http_errors_verbatim() {
     let project = tempfile::tempdir().expect("create project dir");
@@ -77,6 +93,86 @@ fn build_returns_error_when_embedding_dimension_changes_across_batches() {
     );
 }
 
+#[test]
+fn build_accepts_high_dimension_embeddings_and_search_roundtrips() {
+    let project = tempfile::tempdir().expect("create project dir");
+    let source_file = write_source_fixture(project.path());
+    let files = vec![source_file.clone()];
+    let mut embed = |texts: Vec<String>| {
+        Ok::<Vec<Vec<f32>>, String>(
+            texts
+                .into_iter()
+                .map(|text| {
+                    if text.contains("handle_request") {
+                        unit_vector(4096, 0)
+                    } else if text.contains("normalize_user_id") {
+                        unit_vector(4096, 1)
+                    } else {
+                        unit_vector(4096, 2)
+                    }
+                })
+                .collect(),
+        )
+    };
+
+    let index = SemanticIndex::build(project.path(), &files, &mut embed, 16)
+        .expect("4096-dimensional build should be accepted");
+    assert_eq!(index.dimension(), 4096);
+
+    let storage = tempfile::tempdir().expect("create storage dir");
+    index.write_to_disk(storage.path(), "high-dim-project");
+    let restored = SemanticIndex::read_from_disk(
+        storage.path(),
+        "high-dim-project",
+        project.path(),
+        false,
+        None,
+    )
+    .expect("restore 4096-dimensional semantic index");
+
+    let results = restored.search(&unit_vector(4096, 1), 1);
+    assert_eq!(results.len(), 1);
+    assert_eq!(results[0].name, "normalize_user_id");
+    assert_eq!(results[0].file, source_file);
+}
+
+#[test]
+fn build_rejects_unsupported_embedding_dimensions() {
+    for dimension in [0usize, 4097] {
+        let project = tempfile::tempdir().expect("create project dir");
+        let source_file = write_source_fixture(project.path());
+        let files = vec![source_file];
+        let mut embed = move |texts: Vec<String>| {
+            Ok::<Vec<Vec<f32>>, String>(texts.into_iter().map(|_| vec![1.0; dimension]).collect())
+        };
+
+        let error = SemanticIndex::build(project.path(), &files, &mut embed, 16)
+            .expect_err("unsupported dimensions should be rejected during build");
+        assert!(
+            error.contains(&format!("invalid embedding dimension: {dimension}"))
+                && error.contains("supported range is 1..=4096"),
+            "error should include dimension and supported range: {error}"
+        );
+    }
+}
+
+#[test]
+fn from_bytes_accepts_and_rejects_dimension_boundaries() {
+    let index = SemanticIndex::from_bytes(&build_empty_v6_bytes(4096), Path::new("/"))
+        .expect("4096 dimensions should deserialize");
+    assert_eq!(index.dimension(), 4096);
+
+    for dimension in [0u32, 4097] {
+        let error = SemanticIndex::from_bytes(&build_empty_v6_bytes(dimension), Path::new("/"))
+            .expect_err("unsupported dimension should be rejected");
+        assert!(
+            error.contains(&format!("invalid embedding dimension: {dimension}"))
+                && error.contains("supported range is 1..=4096"),
+            "error should include supported range: {error}"
+        );
+    }
+}
+
 #[test]
 fn build_returns_error_when_embedding_backend_returns_too_few_vectors() {
     let project = tempfile::tempdir().expect("create project dir");