From d1c97cdcc764fa10027d748da60c5ab66fea959d Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 25 Oct 2023 18:08:50 +0200 Subject: [PATCH] feat: support camembert --- README.md | 2 +- backends/candle/src/lib.rs | 1 + router/src/main.rs | 13 +++++++------ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 265285bc..b0e84029 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ such as: ### Supported Models -You can use any BERT or XLM-RoBERTa model with absolute positions in `text-embeddings-inference`. +You can use any BERT, CamemBERT or XLM-RoBERTa model with absolute positions in `text-embeddings-inference`. **Support for other model types will be added in the future.** diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 6525c74e..897cd380 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -34,6 +34,7 @@ impl CandleBackend { // Check model type if config.model_type != Some("bert".to_string()) && config.model_type != Some("xlm-roberta".to_string()) + && config.model_type != Some("camembert".to_string()) { return Err(BackendError::Start(format!( "Model {:?} is not supported", diff --git a/router/src/main.rs b/router/src/main.rs index 06683116..ade6adf2 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -214,12 +214,13 @@ async fn main() -> Result<()> { ); tokenizer.with_padding(None); - // Position IDs offset. Used for Roberta. - let position_offset = if &config.model_type == "xlm-roberta" { - config.pad_token_id + 1 - } else { - 0 - }; + // Position IDs offset. Used for Roberta and camembert. + let position_offset = + if &config.model_type == "xlm-roberta" || &config.model_type == "camembert" { + config.pad_token_id + 1 + } else { + 0 + }; let max_input_length = config.max_position_embeddings - position_offset; let tokenization_workers = args