feat: splade pooling #174

Merged
merged 6 commits on Feb 29, 2024
9 changes: 6 additions & 3 deletions README.md
@@ -152,13 +152,16 @@ Options:
--pooling <POOLING>
Optionally control the pooling method for embedding models.

If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json`
configuration.
If `pooling` is not set, the pooling configuration will be parsed from the model `1_Pooling/config.json` configuration.

If `pooling` is set, it will override the model pooling configuration

[env: POOLING=]
[possible values: cls, mean]

Possible values:
- cls: Select the CLS token as embedding
- mean: Apply Mean pooling to the model embeddings
- splade: Apply SPLADE (Sparse Lexical and Expansion) to the model embeddings. This option is only available if the loaded model is a `ForMaskedLM` Transformer model

--max-concurrent-requests <MAX_CONCURRENT_REQUESTS>
The maximum amount of concurrent requests for this particular deployment.
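For readers new to the option: SPLADE pooling turns the masked-language-model logits into a sparse, vocabulary-sized vector by applying `log(1 + ReLU(logits))` and max-pooling over the sequence dimension. Below is a minimal candle sketch of that transform; the function name, tensor names, and shapes are illustrative assumptions, not the PR's exact code, and padding handling is omitted.

```rust
use candle::{Result, Tensor};

/// SPLADE activation over masked-LM logits (rough sketch).
/// `logits` is assumed to have shape (batch, seq_len, vocab_size);
/// the result is a sparse (batch, vocab_size) vector of term weights.
fn splade_pool(logits: &Tensor) -> Result<Tensor> {
    // ReLU keeps only positively activated vocabulary terms.
    let relu = logits.relu()?;
    // log(1 + x) dampens very large activations.
    let log1p = relu.affine(1.0, 1.0)?.log()?;
    // Max over the sequence dimension yields one weight per vocabulary term.
    log1p.max(1)
}
```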
31 changes: 20 additions & 11 deletions backends/candle/src/layers/layer_norm.rs
@@ -23,12 +23,15 @@ impl LayerNorm {
})
}

pub fn forward(&self, hidden_states: &Tensor, residual: &Tensor) -> Result<Tensor> {
pub fn forward(&self, hidden_states: &Tensor, residual: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();

match hidden_states.device() {
Device::Cpu | Device::Metal(_) => {
let hidden_states = hidden_states.add(residual)?;
let mut hidden_states = hidden_states.clone();
if let Some(residual) = residual {
hidden_states = hidden_states.add(residual)?;
}
let hidden_states_dtype = hidden_states.dtype();
let internal_dtype = match hidden_states_dtype {
DType::F16 | DType::BF16 => DType::F32,
@@ -51,19 +54,25 @@ impl LayerNorm {
Device::Cuda(_) => {
#[cfg(feature = "cuda")]
{
use candle_layer_norm::fused_add_layer_norm;
use candle_layer_norm::{fused_add_layer_norm, layer_norm};

let original_shape = hidden_states.shape();
let hidden_states = hidden_states.flatten_to(D::Minus2)?;
let residual = residual.flatten_to(D::Minus2)?;

let (result, _) = fused_add_layer_norm(
&hidden_states,
&residual,
&self.weight,
Some(&self.bias),
self.epsilon,
)?;
let result = if let Some(residual) = residual {
let residual = residual.flatten_to(D::Minus2)?;

let (result, _) = fused_add_layer_norm(
&hidden_states,
&residual,
&self.weight,
Some(&self.bias),
self.epsilon,
)?;
Ok(result)
} else {
layer_norm(&hidden_states, &self.weight, Some(&self.bias), self.epsilon)
}?;
result.reshape(original_shape)
}
#[cfg(not(feature = "cuda"))]
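The change above makes the layer norm's residual input optional, so architectures without a fused residual add (such as DistilBert) can reuse the same layer. A rough sketch of the new CPU-path behaviour, with assumed parameter names rather than the crate's exact fields:

```rust
use candle::{DType, Result, Tensor, D};

/// Optional residual add followed by a plain layer norm (sketch only).
/// `weight` and `bias` are assumed to be 1-D tensors of length `hidden_size`
/// in the same dtype as `hidden_states`.
fn layer_norm_opt_residual(
    hidden_states: &Tensor,
    residual: Option<&Tensor>,
    weight: &Tensor,
    bias: &Tensor,
    epsilon: f32,
) -> Result<Tensor> {
    let original_dtype = hidden_states.dtype();
    // Fuse the residual add only when a residual is provided.
    let x = match residual {
        Some(r) => hidden_states.add(r)?,
        None => hidden_states.clone(),
    };
    // Normalize in f32 for stability, as the original code does for f16/bf16.
    let x = x.to_dtype(DType::F32)?;
    let mean = x.mean_keepdim(D::Minus1)?;
    let x = x.broadcast_sub(&mean)?;
    let var = x.sqr()?.mean_keepdim(D::Minus1)?;
    let x = x.broadcast_div(&var.affine(1.0, epsilon as f64)?.sqrt()?)?;
    // Restore the original dtype, then scale and shift.
    x.to_dtype(original_dtype)?
        .broadcast_mul(weight)?
        .broadcast_add(bias)
}
```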
35 changes: 33 additions & 2 deletions backends/candle/src/lib.rs
@@ -11,10 +11,13 @@ use crate::compute_cap::{
get_compile_compute_cap, get_runtime_compute_cap, incompatible_compute_cap,
};
use crate::models::{
BertModel, JinaBertModel, Model, NomicBertModel, NomicConfig, PositionEmbeddingType,
BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, Model, NomicBertModel,
NomicConfig, PositionEmbeddingType,
};
#[cfg(feature = "cuda")]
use crate::models::{FlashBertModel, FlashJinaBertModel, FlashNomicBertModel};
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashNomicBertModel,
};
use candle::{DType, Device};
use candle_nn::VarBuilder;
use models::BertConfig;
@@ -33,6 +36,8 @@ enum Config {
XlmRoberta(BertConfig),
Camembert(BertConfig),
Roberta(BertConfig),
#[serde(rename(deserialize = "distilbert"))]
DistilBert(DistilBertConfig),
#[serde(rename(deserialize = "nomic_bert"))]
NomicBert(NomicConfig),
}
@@ -119,6 +124,12 @@ impl CandleBackend {
BertModel::load_roberta(vb, &config, model_type).s()?,
))
}
(Config::DistilBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting DistilBertModel model on {:?}", device);
Ok(Box::new(
DistilBertModel::load(vb, &config, model_type).s()?,
))
}
(Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting NomicBertModel model on {:?}", device);
Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
@@ -175,6 +186,26 @@
}
}
#[cfg(feature = "cuda")]
(Config::DistilBert(config), Device::Cuda(_)) => {
if cfg!(feature = "flash-attn")
&& dtype == DType::F16
&& &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
== "true"
{
tracing::info!("Starting FlashNomicBertModel model on {:?}", device);
Ok(Box::new(
FlashDistilBertModel::load(vb, &config, model_type).s()?,
))
} else {
tracing::info!("Starting DistilBertModel model on {:?}", device);
Ok(Box::new(
DistilBertModel::load(vb, &config, model_type).s()?,
))
}
}
#[cfg(feature = "cuda")]
(Config::NomicBert(config), Device::Cuda(_)) => {
if cfg!(feature = "flash-attn")
&& dtype == DType::F16
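The new DistilBert CUDA branch uses the same runtime gate as the other architectures: flash attention is only taken when the binary was built with the `flash-attn` feature, the model runs in f16, and `USE_FLASH_ATTENTION` is not disabled. Condensed into a small helper for clarity (the helper name is ours, not part of the crate):

```rust
use candle::DType;

/// Runtime gate mirroring the dispatch above.
fn should_use_flash_attention(dtype: DType) -> bool {
    cfg!(feature = "flash-attn")
        && dtype == DType::F16
        // Defaults to "True"; setting USE_FLASH_ATTENTION=false opts out.
        && std::env::var("USE_FLASH_ATTENTION")
            .unwrap_or_else(|_| "True".to_string())
            .to_lowercase()
            == "true"
}
```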
12 changes: 10 additions & 2 deletions backends/candle/src/models.rs
@@ -5,20 +5,25 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

mod bert;
mod distilbert;
mod jina;
mod nomic;

#[cfg(feature = "cuda")]
mod flash_bert;

#[cfg(feature = "cuda")]
mod flash_jina;
mod jina;

#[cfg(feature = "cuda")]
mod flash_nomic;
mod nomic;

#[cfg(feature = "cuda")]
mod flash_distilbert;

pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
use candle::{Result, Tensor};
pub use distilbert::{DistilBertConfig, DistilBertModel};
pub use jina::JinaBertModel;
pub use nomic::{NomicBertModel, NomicConfig};
use text_embeddings_backend_core::Batch;
@@ -32,6 +37,9 @@ pub use flash_jina::FlashJinaBertModel;
#[cfg(feature = "cuda")]
pub use flash_nomic::FlashNomicBertModel;

#[cfg(feature = "cuda")]
pub use flash_distilbert::FlashDistilBertModel;

pub(crate) trait Model {
fn is_padded(&self) -> bool;

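The module changes follow the crate's existing pattern: each architecture gets an unconditional module, and its flash-attention counterpart is only compiled and re-exported behind the `cuda` feature. A hypothetical, trimmed-down layout showing the gating (placeholder types, not the real implementations):

```rust
mod distilbert {
    pub struct DistilBertModel; // placeholder for the real implementation
}
pub use distilbert::DistilBertModel;

#[cfg(feature = "cuda")]
mod flash_distilbert {
    pub struct FlashDistilBertModel; // placeholder; needs CUDA kernels in reality
}
#[cfg(feature = "cuda")]
pub use flash_distilbert::FlashDistilBertModel;
```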
21 changes: 14 additions & 7 deletions backends/candle/src/models/bert.rs
@@ -26,7 +26,6 @@ pub struct BertConfig {
#[serde(default)]
pub use_cache: bool,
pub classifier_dropout: Option<f64>,
pub model_type: Option<String>,
pub id2label: Option<HashMap<String, String>>,
}

@@ -39,7 +38,7 @@ pub enum PositionEmbeddingType {
}

#[derive(Debug)]
struct BertEmbeddings {
pub struct BertEmbeddings {
word_embeddings: Embedding,
token_type_embeddings: Embedding,
position_embeddings: Embedding,
@@ -80,7 +79,7 @@ impl BertEmbeddings {
})
}

fn forward(
pub fn forward(
&self,
input_ids: &Tensor,
token_type_ids: &Tensor,
@@ -93,7 +92,9 @@ impl BertEmbeddings {
let position_embeddings = self.position_embeddings.forward(position_ids)?;

let embeddings = input_embeddings.add(&token_type_embeddings)?;
let embeddings = self.layer_norm.forward(&embeddings, &position_embeddings)?;
let embeddings = self
.layer_norm
.forward(&embeddings, Some(&position_embeddings))?;

Ok(embeddings)
}
@@ -255,7 +256,7 @@ impl BertAttention {
let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;

let hidden_states = self.dense.forward(&context_layer)?;
let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?;
let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?;

Ok(hidden_states)
}
@@ -324,7 +325,7 @@ impl BertLayer {

let hidden_states = self.intermediate.forward(&hidden_states)?;
let hidden_states = self.output.forward(&hidden_states)?;
let hidden_states = self.layer_norm.forward(&hidden_states, &residual)?;
let hidden_states = self.layer_norm.forward(&hidden_states, Some(&residual))?;

Ok(hidden_states)
}
@@ -469,7 +470,12 @@ impl BertModel {
Box::new(BertClassificationHead::load(vb.pp("classifier"), config)?);
(pool, Some(classifier))
}
ModelType::Embedding(pool) => (pool, None),
ModelType::Embedding(pool) => {
if pool == Pool::Splade {
candle::bail!("`splade` is not supported for Bert")
}
(pool, None)
}
};

let (embeddings, encoder) = match (
Expand Down Expand Up @@ -724,6 +730,7 @@ impl BertModel {

(outputs.sum(1)?.broadcast_div(&input_lengths))?
}
Pool::Splade => unreachable!(),
};
Some(pooled_embeddings)
} else {
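The `BertModel` changes route SPLADE away from the dense pooling path: embedding models configured with `Pool::Splade` are rejected at load time, so the pooling match can treat that variant as unreachable and SPLADE requests are served by the masked-LM path (e.g. the DistilBert models added here). A condensed sketch of that match, with a placeholder `Pool` enum and assumed shapes:

```rust
use candle::{Result, Tensor};

// Trimmed-down pooling enum for illustration; the real `Pool` comes from
// `text_embeddings_backend_core`.
enum Pool {
    Cls,
    Mean,
    Splade,
}

/// Dense pooling over encoder `outputs` of assumed shape (batch, seq_len, hidden),
/// with `input_lengths` of assumed shape (batch, 1).
fn pool_embeddings(outputs: &Tensor, input_lengths: &Tensor, pool: &Pool) -> Result<Tensor> {
    match pool {
        // CLS pooling: keep the first token of every sequence.
        Pool::Cls => outputs.narrow(1, 0, 1)?.squeeze(1),
        // Mean pooling: sum over the sequence and divide by the true lengths.
        Pool::Mean => outputs.sum(1)?.broadcast_div(input_lengths),
        // Rejected at load time for this model, so it never reaches here.
        Pool::Splade => unreachable!(),
    }
}
```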