feat: support camembert #42

Merged · 1 commit · Oct 26, 2023
2 changes: 1 addition & 1 deletion — README.md

@@ -53,7 +53,7 @@ such as:

 ### Supported Models

-You can use any BERT or XLM-RoBERTa model with absolute positions in `text-embeddings-inference`.
+You can use any BERT, CamemBERT or XLM-RoBERTa model with absolute positions in `text-embeddings-inference`.

 **Support for other model types will be added in the future.**
1 change: 1 addition & 0 deletions — backends/candle/src/lib.rs

@@ -34,6 +34,7 @@ impl CandleBackend {
         // Check model type
         if config.model_type != Some("bert".to_string())
             && config.model_type != Some("xlm-roberta".to_string())
+            && config.model_type != Some("camembert".to_string())
         {
             return Err(BackendError::Start(format!(
                 "Model {:?} is not supported",
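The chained `!=` comparisons in the backend act as a whitelist of accepted `model_type` values from `config.json`. A minimal standalone sketch of that gate (the `is_supported` helper is hypothetical and not part of this PR; the real check lives inside `CandleBackend` and returns a `BackendError`):

```rust
// Hypothetical standalone version of the model-type whitelist added in this PR.
// `model_type` mirrors the optional `model_type` field of a Hugging Face config.json.
fn is_supported(model_type: Option<&str>) -> bool {
    matches!(
        model_type,
        Some("bert") | Some("xlm-roberta") | Some("camembert")
    )
}

fn main() {
    // After this PR, "camembert" passes the gate alongside "bert" and "xlm-roberta".
    assert!(is_supported(Some("camembert")));
    assert!(is_supported(Some("bert")));
    // Anything else, including a missing model_type, is rejected at startup.
    assert!(!is_supported(Some("gpt2")));
    assert!(!is_supported(None));
    println!("ok");
}
```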
13 changes: 7 additions & 6 deletions — router/src/main.rs

@@ -214,12 +214,13 @@ async fn main() -> Result<()> {
     );
     tokenizer.with_padding(None);

-    // Position IDs offset. Used for Roberta.
-    let position_offset = if &config.model_type == "xlm-roberta" {
-        config.pad_token_id + 1
-    } else {
-        0
-    };
+    // Position IDs offset. Used for Roberta and camembert.
+    let position_offset =
+        if &config.model_type == "xlm-roberta" || &config.model_type == "camembert" {
+            config.pad_token_id + 1
+        } else {
+            0
+        };
     let max_input_length = config.max_position_embeddings - position_offset;

     let tokenization_workers = args
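The router change extends the RoBERTa position-id convention to CamemBERT: RoBERTa-family models start position ids at `pad_token_id + 1` rather than 0, which also shrinks the usable input length. A sketch of the rule with typical values (assumptions, not taken from this diff: RoBERTa-family configs commonly ship `pad_token_id = 1` and `max_position_embeddings = 514`):

```rust
// Sketch of the position-offset rule from router/src/main.rs, extracted into a
// free function for illustration (the original computes it inline from `config`).
fn position_offset(model_type: &str, pad_token_id: usize) -> usize {
    if model_type == "xlm-roberta" || model_type == "camembert" {
        // RoBERTa convention: position ids begin just past the padding id.
        pad_token_id + 1
    } else {
        0
    }
}

fn main() {
    // Assumed typical CamemBERT config values: pad_token_id = 1,
    // max_position_embeddings = 514, giving 514 - 2 = 512 usable tokens.
    let offset = position_offset("camembert", 1);
    assert_eq!(offset, 2);
    assert_eq!(514 - offset, 512);
    // BERT-style models keep the full range.
    assert_eq!(position_offset("bert", 0), 0);
    println!("ok");
}
```

Without this branch, a CamemBERT model would have reported a `max_input_length` two tokens too large, allowing inputs that overflow the position-embedding table.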