Add Splade #1747

Open
wants to merge 1 commit into main
72 changes: 70 additions & 2 deletions candle-transformers/src/models/bert.rs
@@ -1,6 +1,6 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
-use candle_nn::{embedding, Embedding, Module, VarBuilder};
+use candle_nn::{embedding, init::DEFAULT_KAIMING_NORMAL, Embedding, Init, Module, VarBuilder};
use serde::Deserialize;

pub const DTYPE: DType = DType::F32;
@@ -13,7 +13,7 @@ pub enum HiddenAct {
Relu,
}

-struct HiddenActLayer {
+pub struct HiddenActLayer {
act: HiddenAct,
span: tracing::Span,
}
@@ -488,3 +488,71 @@ impl BertModel {
Ok(sequence_output)
}
}

struct BertPredictionHeadTransform {
dense: Linear,
act: HiddenActLayer,
layernorm: LayerNorm,
span: tracing::Span,
}

impl BertPredictionHeadTransform {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let hidden_size = config.hidden_size;
let layer_norm_eps = config.layer_norm_eps;

let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
        // Note: the activation is hardcoded to GeluApproximate rather than
        // read from config.hidden_act.
        let act = HiddenActLayer::new(HiddenAct::GeluApproximate);
let layernorm = layer_norm(hidden_size, layer_norm_eps, vb.pp("LayerNorm"))?;
Ok(Self {
dense,
act,
layernorm,
span: tracing::span!(tracing::Level::TRACE, "self-bert-pred-head-trans"),
})
}
pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.act.forward(&hidden_states)?;
let layer_output = self.layernorm.forward(&hidden_states)?;
Ok(layer_output)
}
}

pub struct BertPredictionHead {
transform: BertPredictionHeadTransform,
decoder: Linear,
span: tracing::Span,
}

/// Builds the vocabulary decoder, tying its weight matrix to the word
/// embeddings (the standard BERT MLM weight sharing); only the bias is a
/// dedicated `cls.predictions.decoder.bias` tensor.
pub fn linear_decoder(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
let ws = vs.get((out_dim, in_dim), "bert.embeddings.word_embeddings.weight")?;
let bs = vs.get(out_dim, "cls.predictions.decoder.bias")?;

Ok(Linear::from_weights(ws, Some(bs)))
}

impl BertPredictionHead {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let hidden_size = config.hidden_size;
let vocab_size = config.vocab_size;
let transform =
BertPredictionHeadTransform::load(vb.pp("cls.predictions.transform"), config)?;
let decoder = linear_decoder(hidden_size, vocab_size, vb)?;

Ok(Self {
transform,
decoder,
span: tracing::span!(tracing::Level::TRACE, "self-bert-pred-head"),
})
}

pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states = self.transform.forward(hidden_states)?;
let layer_output = self.decoder.forward(&hidden_states)?;

Ok(layer_output)
}
}
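For orientation: because of the weight tying above, `BertPredictionHead::load` takes the root `VarBuilder` rather than a `cls`-scoped one, since the tensor paths it reads (`cls.predictions.*`, `bert.embeddings.*`) are absolute. A minimal sketch of driving the head on its own, assuming `vb`, `config`, and `hidden` come from an already-loaded bert-base-uncased checkpoint:

use candle::{Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertPredictionHead, Config};

// Maps encoder hidden states to vocabulary logits:
// (batch, seq_len, hidden_size) -> (batch, seq_len, vocab_size).
fn mlm_logits(vb: VarBuilder, config: &Config, hidden: &Tensor) -> Result<Tensor> {
    // Root VarBuilder required: the weight paths inside are absolute.
    let head = BertPredictionHead::load(vb, config)?;
    head.forward(hidden)
}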
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
@@ -35,6 +35,7 @@ pub mod qwen2;
pub mod repvgg;
pub mod resnet;
pub mod segment_anything;
pub mod splade;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod t5;
67 changes: 67 additions & 0 deletions candle-transformers/src/models/splade.rs
@@ -0,0 +1,67 @@
use std::collections::HashMap;

use super::bert::{BertModel, BertPredictionHead, Config};
use candle::{Device, Result, Tensor};
use candle_nn::VarBuilder;

// Note: this wrapper is currently unused within this module.
struct SpladePredictionHead {
head: BertPredictionHead,
}

pub struct SpladeModel {
bert: BertModel,
head: BertPredictionHead,
pub device: Device,
span: tracing::Span,
}

impl SpladeModel {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let bert = BertModel::load(vb.pp("bert"), config)?;
let head = BertPredictionHead::load(vb.clone(), config)?;

Ok(Self {
bert,
head,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}

pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let bert_output = self.bert.forward(input_ids, token_type_ids)?;

let predictions = self.head.forward(&bert_output)?;
        // SPLADE activation: log(1 + ReLU(logits)), summed over the sequence
        // dimension to pool per-token scores into one vocabulary-sized vector.
        let scores = (1. + predictions.relu()?)?.log()?.sum(1)?;
        Ok(scores)
}

pub fn sparse_forward(
&self,
input_ids: &Tensor,
token_type_ids: &Tensor,
filter: Option<f32>,
) -> Result<HashMap<u32, f32>> {
        let dense_terms = self.forward(input_ids, token_type_ids)?;

        let filter = filter.unwrap_or(1.0);

        let value_map: HashMap<u32, f32> = dense_terms
            .flatten_all()?
            .to_vec1::<f32>()?
            .into_iter()
            .enumerate()
            // Keep activations above the threshold. The hardcoded id range
            // appears to restrict output to a slice of the bert-base-uncased
            // vocabulary, skipping special and rarely-used token ids.
            .filter(|&(idx, v)| v > filter && (1996..29612).contains(&idx))
            .map(|(idx, v)| (idx as u32, v))
            .collect();
Ok(value_map)
}
}
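For context, a minimal end-to-end sketch. The file names, the anyhow/serde_json dependencies, and the hardcoded token ids are assumptions for illustration; any SPLADE checkpoint with a BERT backbone (e.g. naver/splade-cocondenser-ensembledistil) should fit this shape:

use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{Config, DTYPE};
use candle_transformers::models::splade::SpladeModel;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DTYPE, &device)? };
    let model = SpladeModel::load(vb, &config)?;

    // Token ids would normally come from the checkpoint's WordPiece tokenizer;
    // these are hardcoded stand-ins for "[CLS] hello world [SEP]".
    let input_ids = Tensor::new(&[[101u32, 7592, 2088, 102]], &device)?;
    let token_type_ids = input_ids.zeros_like()?;

    // Sparse term weights above the default threshold of 1.0.
    for (token_id, weight) in model.sparse_forward(&input_ids, &token_type_ids, None)? {
        println!("{token_id}: {weight:.3}");
    }
    Ok(())
}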