Skip to content

Commit

Permalink
Merge #4509
Browse files Browse the repository at this point in the history
4509: Rest embedder r=ManyTheFish a=dureuill

Fixes #4531 

See [Usage page](https://meilisearch.notion.site/v1-8-AI-search-API-usage-135552d6e85a4a52bc7109be82aeca42?pvs=25#e6f58c3b742c4effb4ddc625ce12ee16)

### Implementation changes

- Remove tokio, futures, reqwests
- Add a new `milli::vector::rest::Embedder` embedder
- Update OpenAI and Ollama embedders to use the REST embedder internally
- Make Embedder::embed a sync method
- Add the new embedder source as described in the usage


Co-authored-by: Louis Dureuil <louis@meilisearch.com>
  • Loading branch information
meili-bors[bot] and dureuill committed Mar 27, 2024
2 parents 5ea017b + dfa5e41 commit 34dfea7
Show file tree
Hide file tree
Showing 18 changed files with 1,032 additions and 748 deletions.
5 changes: 2 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions meilisearch-types/src/error.rs
Expand Up @@ -353,6 +353,7 @@ impl ErrorCode for milli::Error {
| UserError::InvalidOpenAiModelDimensions { .. }
| UserError::InvalidOpenAiModelDimensionsMax { .. }
| UserError::InvalidSettingsDimensions { .. }
| UserError::InvalidUrl { .. }
| UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
Expand Down
7 changes: 3 additions & 4 deletions meilisearch/src/routes/indexes/search.rs
Expand Up @@ -202,7 +202,7 @@ pub async fn search_with_url_query(
let index = index_scheduler.index(&index_uid)?;
let features = index_scheduler.features();

let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?;
let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?;

let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
Expand Down Expand Up @@ -241,7 +241,7 @@ pub async fn search_with_post(

let features = index_scheduler.features();

let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?;
let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?;

let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
Expand All @@ -260,7 +260,7 @@ pub async fn search_with_post(
Ok(HttpResponse::Ok().json(search_result))
}

pub async fn embed(
pub fn embed(
query: &mut SearchQuery,
index_scheduler: &IndexScheduler,
index: &milli::Index,
Expand All @@ -287,7 +287,6 @@ pub async fn embed(

let embeddings = embedder
.embed(vec![q.to_owned()])
.await
.map_err(milli::vector::Error::from)
.map_err(milli::Error::from)?
.pop()
Expand Down
1 change: 1 addition & 0 deletions meilisearch/src/routes/indexes/settings.rs
Expand Up @@ -605,6 +605,7 @@ fn embedder_analytics(
EmbedderSource::HuggingFace => sources.insert("huggingFace"),
EmbedderSource::UserProvided => sources.insert("userProvided"),
EmbedderSource::Ollama => sources.insert("ollama"),
EmbedderSource::Rest => sources.insert("rest"),
};
}
};
Expand Down
5 changes: 2 additions & 3 deletions meilisearch/src/routes/multi_search.rs
Expand Up @@ -75,9 +75,8 @@ pub async fn multi_search_with_post(
})
.with_index(query_index)?;

let distribution = embed(&mut query, index_scheduler.get_ref(), &index)
.await
.with_index(query_index)?;
let distribution =
embed(&mut query, index_scheduler.get_ref(), &index).with_index(query_index)?;

let search_result = tokio::task::spawn_blocking(move || {
perform_search(&index, query, features, distribution)
Expand Down
8 changes: 2 additions & 6 deletions milli/Cargo.toml
Expand Up @@ -80,17 +80,13 @@ tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.
hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [
"online",
] }
tokio = { version = "1.35.1", features = ["rt"] }
futures = "0.3.30"
reqwest = { version = "0.11.23", features = [
"rustls-tls",
"json",
], default-features = false }
tiktoken-rs = "0.5.8"
liquid = "0.26.4"
arroy = "0.2.0"
rand = "0.8.5"
tracing = "0.1.40"
ureq = { version = "2.9.6", features = ["json"] }
url = "2.5.0"

[dev-dependencies]
mimalloc = { version = "0.1.39", default-features = false }
Expand Down
2 changes: 2 additions & 0 deletions milli/src/error.rs
Expand Up @@ -243,6 +243,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
},
#[error("`.embedders.{embedder_name}.dimensions`: `dimensions` cannot be zero")]
InvalidSettingsDimensions { embedder_name: String },
#[error("`.embedders.{embedder_name}.url`: could not parse `{url}`: {inner_error}")]
InvalidUrl { embedder_name: String, inner_error: url::ParseError, url: String },
}

impl From<crate::vector::Error> for Error {
Expand Down
Expand Up @@ -339,6 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters,
embedder: Arc<Embedder>,
request_threads: &rayon::ThreadPool,
) -> Result<grenad::Reader<BufReader<File>>> {
puffin::profile_function!();
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
Expand Down Expand Up @@ -376,7 +377,10 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(

if chunks.len() == chunks.capacity() {
let chunked_embeds = embedder
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
.embed_chunks(
std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)),
request_threads,
)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;

Expand All @@ -394,7 +398,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
// send last chunk
if !chunks.is_empty() {
let chunked_embeds = embedder
.embed_chunks(std::mem::take(&mut chunks))
.embed_chunks(std::mem::take(&mut chunks), request_threads)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
Expand All @@ -408,7 +412,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(

if !current_chunk.is_empty() {
let embeds = embedder
.embed_chunks(vec![std::mem::take(&mut current_chunk)])
.embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;

Expand Down
13 changes: 12 additions & 1 deletion milli/src/update/index_documents/extract/mod.rs
Expand Up @@ -238,6 +238,12 @@ fn send_original_documents_data(

let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();

let request_threads = rayon::ThreadPoolBuilder::new()
.num_threads(crate::vector::REQUEST_PARALLELISM)
.thread_name(|index| format!("embedding-request-{index}"))
.build()?;

rayon::spawn(move || {
for (name, (embedder, prompt)) in embedders {
let result = extract_vector_points(
Expand All @@ -249,7 +255,12 @@ fn send_original_documents_data(
);
match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
let embeddings = match extract_embeddings(
prompts,
indexer,
embedder.clone(),
&request_threads,
) {
Ok(results) => Some(results),
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
Expand Down
6 changes: 6 additions & 0 deletions milli/src/update/index_documents/mod.rs
Expand Up @@ -2646,6 +2646,12 @@ mod tests {
api_key: Setting::NotSet,
dimensions: Setting::Set(3),
document_template: Setting::NotSet,
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
}),
);
settings.set_embedder_settings(embedders);
Expand Down
82 changes: 80 additions & 2 deletions milli/src/update/settings.rs
Expand Up @@ -1140,6 +1140,12 @@ fn validate_prompt(
api_key,
dimensions,
document_template: Setting::Set(template),
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}) => {
// validate
let template = crate::prompt::Prompt::new(template)
Expand All @@ -1153,6 +1159,12 @@ fn validate_prompt(
api_key,
dimensions,
document_template: Setting::Set(template),
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}))
}
new => Ok(new),
Expand All @@ -1165,8 +1177,20 @@ pub fn validate_embedding_settings(
) -> Result<Setting<EmbeddingSettings>> {
let settings = validate_prompt(name, settings)?;
let Setting::Set(settings) = settings else { return Ok(settings) };
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
settings;
let EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
} = settings;

if let Some(0) = dimensions.set() {
return Err(crate::error::UserError::InvalidSettingsDimensions {
Expand All @@ -1175,6 +1199,14 @@ pub fn validate_embedding_settings(
.into());
}

if let Some(url) = url.as_ref().set() {
url::Url::parse(url).map_err(|error| crate::error::UserError::InvalidUrl {
embedder_name: name.to_owned(),
inner_error: error,
url: url.to_owned(),
})?;
}

let Some(inferred_source) = source.set() else {
return Ok(Setting::Set(EmbeddingSettings {
source,
Expand All @@ -1183,11 +1215,25 @@ pub fn validate_embedding_settings(
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}));
};
match inferred_source {
EmbedderSource::OpenAi => {
check_unset(&revision, "revision", inferred_source, name)?;

check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;

if let Setting::Set(model) = &model {
let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str())
.ok_or(crate::error::UserError::InvalidOpenAiModel {
Expand Down Expand Up @@ -1224,17 +1270,43 @@ pub fn validate_embedding_settings(
check_set(&model, "model", inferred_source, name)?;
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?;

check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
}
EmbedderSource::HuggingFace => {
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&dimensions, "dimensions", inferred_source, name)?;

check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
}
EmbedderSource::UserProvided => {
check_unset(&model, "model", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&document_template, "documentTemplate", inferred_source, name)?;
check_set(&dimensions, "dimensions", inferred_source, name)?;

check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
}
EmbedderSource::Rest => {
check_unset(&model, "model", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?;
check_set(&url, "url", inferred_source, name)?;
}
}
Ok(Setting::Set(EmbeddingSettings {
Expand All @@ -1244,6 +1316,12 @@ pub fn validate_embedding_settings(
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}))
}

Expand Down

0 comments on commit 34dfea7

Please sign in to comment.