Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions crates/aft/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ use url::Url;

const DEFAULT_DIMENSION: usize = 384;
const MAX_ENTRIES: usize = 1_000_000;
const MAX_DIMENSION: usize = 1024;
// Covers high-dimensional backends such as OpenAI text-embedding-3-large (3072)
// and common local models (4096) while keeping a bounded supported shape.
const MAX_DIMENSION: usize = 4096;
const F32_BYTES: usize = std::mem::size_of::<f32>();
const HEADER_BYTES_V1: usize = 9;
const HEADER_BYTES_V2: usize = 13;
Expand Down Expand Up @@ -176,6 +178,8 @@ fn validate_embedding_batch(
return Ok(());
};
let expected_dimension = first_vector.len();
validate_embedding_dimension(expected_dimension)
.map_err(|error| format!("{context} returned {error}"))?;
for (index, vector) in vectors.iter().enumerate() {
if vector.len() != expected_dimension {
return Err(format!(
Expand All @@ -188,6 +192,16 @@ fn validate_embedding_batch(
Ok(())
}

fn validate_embedding_dimension(dimension: usize) -> Result<(), String> {
if dimension == 0 || dimension > MAX_DIMENSION {
return Err(format!(
"invalid embedding dimension: {dimension}; supported range is 1..={MAX_DIMENSION}"
));
}

Ok(())
}

/// Normalize a base URL: validate scheme and strip trailing slash.
/// Does NOT perform SSRF/private-IP validation — call
/// `validate_base_url_no_ssrf` separately when processing user-supplied config.
Expand Down Expand Up @@ -1837,9 +1851,7 @@ impl SemanticIndex {

let dimension = read_u32(data, &mut pos)? as usize;
let entry_count = read_u32(data, &mut pos)? as usize;
if dimension == 0 || dimension > MAX_DIMENSION {
return Err(format!("invalid embedding dimension: {}", dimension));
}
validate_embedding_dimension(dimension)?;
if entry_count > MAX_ENTRIES {
return Err(format!("too many semantic index entries: {}", entry_count));
}
Expand Down
98 changes: 97 additions & 1 deletion crates/aft/tests/semantic_validation_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::fs;
use std::path::PathBuf;
use std::path::{Path, PathBuf};

use aft::semantic_index::SemanticIndex;

Expand All @@ -14,6 +14,22 @@ fn write_source_fixture(project_root: &std::path::Path) -> PathBuf {
source_file
}

fn build_empty_v6_bytes(dimension: u32) -> Vec<u8> {
let mut bytes = Vec::new();
bytes.push(6u8);
bytes.extend_from_slice(&dimension.to_le_bytes());
bytes.extend_from_slice(&0u32.to_le_bytes()); // entry_count
bytes.extend_from_slice(&0u32.to_le_bytes()); // fingerprint_len
bytes.extend_from_slice(&0u32.to_le_bytes()); // mtime_count
bytes
}

fn unit_vector(dimension: usize, hot_index: usize) -> Vec<f32> {
let mut vector = vec![0.0; dimension];
vector[hot_index] = 1.0;
vector
}

#[test]
fn build_returns_backend_http_errors_verbatim() {
let project = tempfile::tempdir().expect("create project dir");
Expand Down Expand Up @@ -77,6 +93,86 @@ fn build_returns_error_when_embedding_dimension_changes_across_batches() {
);
}

#[test]
fn build_accepts_high_dimension_embeddings_and_search_roundtrips() {
let project = tempfile::tempdir().expect("create project dir");
let source_file = write_source_fixture(project.path());
let files = vec![source_file.clone()];
let mut embed = |texts: Vec<String>| {
Ok::<Vec<Vec<f32>>, String>(
texts
.into_iter()
.map(|text| {
if text.contains("handle_request") {
unit_vector(4096, 0)
} else if text.contains("normalize_user_id") {
unit_vector(4096, 1)
} else {
unit_vector(4096, 2)
}
})
.collect(),
)
};

let index = SemanticIndex::build(project.path(), &files, &mut embed, 16)
.expect("4096-dimensional build should be accepted");
assert_eq!(index.dimension(), 4096);

let storage = tempfile::tempdir().expect("create storage dir");
index.write_to_disk(storage.path(), "high-dim-project");
let restored = SemanticIndex::read_from_disk(
storage.path(),
"high-dim-project",
project.path(),
false,
None,
)
.expect("restore 4096-dimensional semantic index");

let results = restored.search(&unit_vector(4096, 1), 1);
assert_eq!(results.len(), 1);
assert_eq!(results[0].name, "normalize_user_id");
assert_eq!(results[0].file, source_file);
}

#[test]
fn build_rejects_unsupported_embedding_dimensions() {
for dimension in [0usize, 4097] {
let project = tempfile::tempdir().expect("create project dir");
let source_file = write_source_fixture(project.path());
let files = vec![source_file];
let mut embed = move |texts: Vec<String>| {
Ok::<Vec<Vec<f32>>, String>(texts.into_iter().map(|_| vec![1.0; dimension]).collect())
};

let error = SemanticIndex::build(project.path(), &files, &mut embed, 16)
.expect_err("unsupported dimensions should be rejected during build");
assert!(
error.contains(&format!("invalid embedding dimension: {dimension}"))
&& error.contains("supported range is 1..=4096"),
"error should include dimension and supported range: {error}"
);
}
}

#[test]
fn from_bytes_accepts_and_rejects_dimension_boundaries() {
let index = SemanticIndex::from_bytes(&build_empty_v6_bytes(4096), Path::new("/"))
.expect("4096 dimensions should deserialize");
assert_eq!(index.dimension(), 4096);

for dimension in [0u32, 4097] {
let error = SemanticIndex::from_bytes(&build_empty_v6_bytes(dimension), Path::new("/"))
.expect_err("unsupported dimension should be rejected");
assert!(
error.contains(&format!("invalid embedding dimension: {dimension}"))
&& error.contains("supported range is 1..=4096"),
"error should include supported range: {error}"
);
}
}

#[test]
fn build_returns_error_when_embedding_backend_returns_too_few_vectors() {
let project = tempfile::tempdir().expect("create project dir");
Expand Down