Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ollama as an embeddings provider #4456

Merged
merged 3 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions meilisearch/src/routes/indexes/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ fn embedder_analytics(
EmbedderSource::OpenAi => sources.insert("openAi"),
EmbedderSource::HuggingFace => sources.insert("huggingFace"),
EmbedderSource::UserProvided => sources.insert("userProvided"),
EmbedderSource::Ollama => sources.insert("ollama"),
};
}
};
Expand Down
7 changes: 7 additions & 0 deletions milli/src/update/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1178,6 +1178,13 @@ pub fn validate_embedding_settings(
}
}
}
EmbedderSource::Ollama => {
// Dimensions get inferred, only model name is required
check_unset(&dimensions, "dimensions", inferred_source, name)?;
check_set(&model, "model", inferred_source, name)?;
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?;
}
EmbedderSource::HuggingFace => {
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&dimensions, "dimensions", inferred_source, name)?;
Expand Down
39 changes: 39 additions & 0 deletions milli/src/vector/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::path::PathBuf;

use hf_hub::api::sync::ApiError;

use super::ollama::OllamaError;
use crate::error::FaultSource;
use crate::vector::openai::OpenAiError;

Expand Down Expand Up @@ -71,6 +72,17 @@ pub enum EmbedErrorKind {
OpenAiRuntimeInit(std::io::Error),
#[error("initializing web client for sending embedding requests failed: {0}")]
InitWebClient(reqwest::Error),
// Dedicated Ollama error kinds, might have to merge them into one cohesive error type for all backends.
#[error("unexpected response from Ollama: {0}")]
OllamaUnexpected(reqwest::Error),
#[error("sent too many requests to Ollama: {0}")]
OllamaTooManyRequests(OllamaError),
#[error("received internal error from Ollama: {0}")]
OllamaInternalServerError(OllamaError),
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0}")]
OllamaModelNotFoundError(OllamaError),
#[error("received unhandled HTTP status code {0} from Ollama")]
OllamaUnhandledStatusCode(u16),
}

impl EmbedError {
Expand Down Expand Up @@ -129,6 +141,26 @@ impl EmbedError {
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
}

pub(crate) fn ollama_unexpected(inner: reqwest::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug }
}

pub(crate) fn ollama_model_not_found(inner: OllamaError) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
}

pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime }
}

pub(crate) fn ollama_internal_server_error(inner: OllamaError) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaInternalServerError(inner), fault: FaultSource::Runtime }
}

pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug }
}
}

#[derive(Debug, thiserror::Error)]
Expand Down Expand Up @@ -195,6 +227,13 @@ impl NewEmbedderError {
}
}

pub fn ollama_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
Self {
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
fault: FaultSource::User,
}
}

pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
}
Expand Down
18 changes: 18 additions & 0 deletions milli/src/vector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub mod manual;
pub mod openai;
pub mod settings;

pub mod ollama;

pub use self::error::Error;

pub type Embedding = Vec<f32>;
Expand Down Expand Up @@ -76,6 +78,7 @@ pub enum Embedder {
HuggingFace(hf::Embedder),
OpenAi(openai::Embedder),
UserProvided(manual::Embedder),
Ollama(ollama::Embedder),
}

#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
Expand Down Expand Up @@ -127,6 +130,7 @@ impl IntoIterator for EmbeddingConfigs {
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions),
Ollama(ollama::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
}

Expand All @@ -144,13 +148,18 @@ impl EmbedderOptions {
pub fn openai(api_key: Option<String>) -> Self {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
}

pub fn ollama() -> Self {
Self::Ollama(ollama::EmbedderOptions::with_default_model())
}
}

impl Embedder {
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
Expand All @@ -167,6 +176,10 @@ impl Embedder {
let client = embedder.new_client()?;
embedder.embed(texts, &client).await
}
Embedder::Ollama(embedder) => {
let client = embedder.new_client()?;
embedder.embed(texts, &client).await
}
Embedder::UserProvided(embedder) => embedder.embed(texts),
}
}
Expand All @@ -181,6 +194,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks),
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
}
}
Expand All @@ -189,6 +203,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
}
}
Expand All @@ -197,6 +212,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
}
}
Expand All @@ -205,6 +221,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::Ollama(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
}
}
Expand All @@ -213,6 +230,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None,
}
}
Expand Down
Loading
Loading