feat: add normalize option (#70)

OlivierDehaene committed Nov 13, 2023
1 parent 618076e commit 2b4b5d2
Showing 18 changed files with 268 additions and 261 deletions.
388 changes: 204 additions & 184 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -17,7 +17,7 @@ authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-embeddings-inference"

[patch.crates-io]
cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "8be6ff46e4a2014fb563570e0d206c09aea88152" }
cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "c19522f1e411ab453d71bdfad3383b118cd4216f" }
candle = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-core" }
candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-nn" }
candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-transformers" }
2 changes: 1 addition & 1 deletion backends/candle/Cargo.toml
@@ -13,7 +13,7 @@ candle-nn = { version = "0.3.0" }
candle-transformers = { version = "0.3.0" }
candle-flash-attn = { version = "0.3.0", optional = true }
candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "62b75f1ea4e0961fad7b983ee8d723ed6fd68be5", optional = true }
candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "ffd246552c266640fab217f964a83960e07a66ec", optional = true }
candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "58684e116aae248c353f87846ddf0b2a8a7ed855", optional = true }
candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "5ed96012a693dff9685320765dd55a57fdaecdd6", optional = true }
lazy_static = "^1.4"
text-embeddings-backend-core = { path = "../core" }
5 changes: 1 addition & 4 deletions backends/candle/src/models/bert.rs
@@ -554,10 +554,7 @@ impl BertModel {
}
};

// Normalize
let normalized_results = results.broadcast_div(&results.sqr()?.sum_keepdim(1)?.sqrt()?)?;

Ok(normalized_results)
Ok(results)
}
}

5 changes: 1 addition & 4 deletions backends/candle/src/models/bert_quant.rs
@@ -448,10 +448,7 @@ impl QuantBertModel {
Pool::Mean => (outputs.sum_keepdim(0)? / (batch.max_length as f64))?,
};

// Normalize
let normalized_results = results.broadcast_div(&results.sqr()?.sum_keepdim(1)?.sqrt()?)?;

Ok(normalized_results)
Ok(results)
}
}

5 changes: 1 addition & 4 deletions backends/candle/src/models/flash_bert.rs
@@ -383,10 +383,7 @@ impl FlashBertModel {
}
};

// Normalize
let normalized_results = results.broadcast_div(&results.sqr()?.sum_keepdim(1)?.sqrt()?)?;

Ok(normalized_results)
Ok(results)
}
}

5 changes: 1 addition & 4 deletions backends/candle/src/models/jina.rs
@@ -582,10 +582,7 @@ impl JinaBertModel {
}
};

// Normalize
let normalized_results = results.broadcast_div(&results.sqr()?.sum_keepdim(1)?.sqrt()?)?;

Ok(normalized_results)
Ok(results)
}
}

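In bert.rs, bert_quant.rs, flash_bert.rs, and jina.rs the pooled embeddings are now returned raw; the removed `broadcast_div` normalization moves into core/src/infer.rs and is only applied when the request asks for it. A minimal, self-contained sketch of the candle operation the models used to perform (crate name and test values are illustrative, not part of this diff):

```rust
use candle_core::{Device, Result, Tensor};

/// Row-wise L2 normalization over a `[batch, hidden]` tensor, equivalent to
/// the lines removed from the four model files.
fn l2_normalize(results: &Tensor) -> Result<Tensor> {
    results.broadcast_div(&results.sqr()?.sum_keepdim(1)?.sqrt()?)
}

fn main() -> Result<()> {
    let pooled = Tensor::new(&[[3.0f32, 4.0], [0.0, 2.0]], &Device::Cpu)?;
    let normalized = l2_normalize(&pooled)?;
    // Each row now has unit L2 norm: [[0.6, 0.8], [0.0, 1.0]].
    println!("{normalized}");
    Ok(())
}
```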
@@ -42,9 +42,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:

output = self.model(**kwargs)
embedding = output[0][:, 0]
results = torch.nn.functional.normalize(embedding, p=2, dim=1)

cpu_results = results.view(-1).tolist()
cpu_results = embedding.view(-1).tolist()

return [
Embedding(
@@ -243,8 +243,7 @@ def embed(self, batch: FlashBatch) -> List[Embedding]:
cu_seqlens=batch.cu_seqlens,
max_s=batch.max_s,
)
results = torch.nn.functional.normalize(embedding, p=2, dim=1)
cpu_results = results.view(-1).tolist()
cpu_results = embedding.view(-1).tolist()

return [
Embedding(
@@ -55,7 +55,9 @@ def _start_span(self, handler_call_details, context, set_status_on_exception=False


def setup_tracing(otlp_endpoint: str):
resource = Resource.create(attributes={"service.name": f"text-embeddings-inference.server"})
resource = Resource.create(
attributes={"service.name": f"text-embeddings-inference.server"}
)
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
span_processor = BatchSpanProcessor(span_exporter)

3 changes: 2 additions & 1 deletion core/Cargo.toml
@@ -8,8 +8,9 @@ homepage.workspace = true
[dependencies]
hf-hub = { version = "^0.3.0", features = ["tokio"] }
metrics = "^0.21"
rayon = "^1.8"
text-embeddings-backend = { path = "../backends" }
thiserror = "^1.0"
tokenizers = { version = "^0.14", default-features=false, features=["onig"] }
tokenizers = { version = "^0.14.1", default-features=false, features=["onig", "esaxx_fast"] }
tracing = "^0.1"
tokio = { version = "^1.25", features = ["rt", "rt-multi-thread", "parking_lot", "sync"] }
21 changes: 20 additions & 1 deletion core/src/infer.rs
@@ -1,6 +1,7 @@
use crate::queue::{Entry, Metadata, NextBatch, Queue};
use crate::tokenization::Tokenization;
use crate::TextEmbeddingsError;
use rayon::prelude::*;
use std::sync::Arc;
use std::time::{Duration, Instant};
use text_embeddings_backend::{Backend, Embedding};
@@ -85,6 +86,7 @@ impl Infer {
&self,
inputs: String,
truncate: bool,
normalize: bool,
permit: OwnedSemaphorePermit,
) -> Result<InferResponse, TextEmbeddingsError> {
let start_time = Instant::now();
@@ -112,6 +114,7 @@ impl Infer {
tokenization: start_time.elapsed(),
queue_time: Instant::now(),
prompt_tokens: encoding.input_ids.len(),
normalize,
},
encoding,
});
@@ -185,7 +188,23 @@ async fn embed_task(
// Handle sending responses in another thread to avoid starving the backend
tokio::task::spawn_blocking(move || match results {
Ok(embeddings) => {
batch.0.into_iter().zip(embeddings).for_each(|(m, e)| {
batch.0.into_par_iter().zip(embeddings).for_each(|(m, e)| {
let e = match m.normalize {
// Normalize embedding
true => {
let scale = (1.0
/ e.iter()
.map(|v| {
let v = *v as f64;
v * v
})
.sum::<f64>()
.sqrt()) as f32;
e.into_iter().map(|v| v * scale).collect()
}
false => e,
};

let _ = m.response_tx.send(Ok(InferResponse {
embeddings: e,
prompt_tokens: m.prompt_tokens,
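The normalization that previously ran on the accelerator inside each model now happens here, per embedding, on the CPU (in parallel via rayon's `into_par_iter`) and only when `normalize` is set. A standalone sketch of the same arithmetic, using an illustrative helper name that does not exist in the crate:

```rust
/// L2-normalize one embedding, accumulating the squared norm in f64 for
/// stability before scaling the f32 values, as in `embed_task`.
fn normalize_embedding(e: Vec<f32>, normalize: bool) -> Vec<f32> {
    if !normalize {
        return e;
    }
    let scale = (1.0
        / e.iter()
            .map(|v| {
                let v = *v as f64;
                v * v
            })
            .sum::<f64>()
            .sqrt()) as f32;
    e.into_iter().map(|v| v * scale).collect()
}

fn main() {
    let normalized = normalize_embedding(vec![3.0, 4.0], true);
    // The normalized vector has unit L2 norm.
    let norm: f32 = normalized.iter().map(|v| v * v).sum::<f32>().sqrt();
    assert!((norm - 1.0).abs() < 1e-6);
    // With normalize == false the embedding is returned untouched.
    assert_eq!(normalize_embedding(vec![3.0, 4.0], false), vec![3.0, 4.0]);
}
```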
49 changes: 11 additions & 38 deletions core/src/queue.rs
@@ -1,9 +1,7 @@
use crate::infer::InferResponse;
use crate::tokenization::Encoding;
use std::alloc::{alloc, Layout};
use std::cmp::max;
use std::collections::VecDeque;
use std::ptr;
use std::time::{Duration, Instant};
use text_embeddings_backend::{BackendError, Batch};
use tokio::sync::{mpsc, oneshot};
@@ -31,6 +29,8 @@ pub struct Metadata {
pub queue_time: Instant,
/// Number of tokens in the prompt
pub prompt_tokens: usize,
/// Normalize the embeddings
pub normalize: bool,
}

/// Request Queue
@@ -114,13 +114,12 @@ fn queue_blocking_task(
QueueCommand::NextBatch {
response_sender,
span,
} => unsafe {
} => {
let _span = span.entered();

// Allocate raw memory
let raw_input_ids = raw_u32_vec(max_batch_tokens);
let raw_token_type_ids = raw_u32_vec(max_batch_tokens);
let raw_position_ids = raw_u32_vec(max_batch_tokens);
let mut input_ids = Vec::with_capacity(max_batch_tokens);
let mut token_type_ids = Vec::with_capacity(max_batch_tokens);
let mut position_ids = Vec::with_capacity(max_batch_tokens);

let mut metadata = Vec::with_capacity(capacity);
let mut cu_seq_lengths = Vec::with_capacity(capacity);
@@ -129,7 +128,7 @@
let mut current_tokens = 0;
let mut max_length = 0;

while let Some(mut entry) = entries.pop_front() {
while let Some(entry) = entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.metadata.response_tx.is_closed() {
@@ -146,22 +145,9 @@

max_length = max(max_length, entry_tokens as u32);

// Copy memory to the correct spot in the raw vectors
ptr::copy(
entry.encoding.input_ids.as_mut_ptr(),
raw_input_ids.add(current_tokens),
entry.encoding.input_ids.len(),
);
ptr::copy(
entry.encoding.token_type_ids.as_mut_ptr(),
raw_token_type_ids.add(current_tokens),
entry.encoding.token_type_ids.len(),
);
ptr::copy(
entry.encoding.position_ids.as_mut_ptr(),
raw_position_ids.add(current_tokens),
entry.encoding.position_ids.len(),
);
input_ids.extend(entry.encoding.input_ids);
token_type_ids.extend(entry.encoding.token_type_ids);
position_ids.extend(entry.encoding.position_ids);

current_tokens += entry_tokens;
metadata.push(entry.metadata);
@@ -172,14 +158,6 @@
}
}

// Create final vectors from raw memory
let input_ids =
Vec::from_raw_parts(raw_input_ids, current_tokens, max_batch_tokens);
let token_type_ids =
Vec::from_raw_parts(raw_token_type_ids, current_tokens, max_batch_tokens);
let position_ids =
Vec::from_raw_parts(raw_position_ids, current_tokens, max_batch_tokens);

let batch_size = metadata.len();
let next_batch = if metadata.is_empty() {
None
@@ -201,16 +179,11 @@
metrics::histogram!("te_batch_next_size", batch_size as f64);
metrics::histogram!("te_batch_next_tokens", current_tokens as f64);
metrics::gauge!("te_queue_size", entries.len() as f64);
},
}
}
}
}

unsafe fn raw_u32_vec(capacity: usize) -> *mut u32 {
let layout = Layout::array::<u32>(capacity).unwrap();
alloc(layout).cast::<u32>()
}

pub type NextBatch = (Vec<Metadata>, Batch);

#[derive(Debug)]
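queue.rs drops the `unsafe` path that allocated uninitialized buffers and `ptr::copy`'d each request into them; `Vec::with_capacity` plus `extend` gives the same single-allocation concatenation safely. A minimal sketch of the pattern (the function and types are illustrative; the real code fills `input_ids`, `token_type_ids`, and `position_ids` the same way):

```rust
/// Concatenate per-request token id buffers into one flat batch buffer.
fn build_batch(requests: &[Vec<u32>], max_batch_tokens: usize) -> Vec<u32> {
    // Reserve the upper bound up front, as the queue does with max_batch_tokens.
    let mut input_ids = Vec::with_capacity(max_batch_tokens);
    for r in requests {
        // `extend_from_slice` copies the ids and grows the length safely,
        // so no uninitialized memory is ever observed.
        input_ids.extend_from_slice(r);
    }
    input_ids
}

fn main() {
    let batch = build_batch(&[vec![101, 2023, 102], vec![101, 102]], 16);
    assert_eq!(batch, vec![101, 2023, 102, 101, 102]);
}
```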
8 changes: 2 additions & 6 deletions core/src/tokenization.rs
@@ -1,5 +1,5 @@
use crate::TextEmbeddingsError;
/// Payload tokenization logic
use crate::TextEmbeddingsError;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;
use tokio::sync::{mpsc, oneshot};
@@ -81,11 +81,7 @@ impl Tokenization {

// Await on response channel
// Unwrap is safe here
let payload = response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")?;

metrics::histogram!("te_request_input_length", payload.input_ids.len() as f64);

Ok(payload)
Ok(response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")?)
}
}

2 changes: 2 additions & 0 deletions load_tests/load.js
@@ -54,5 +54,7 @@ export default function () {
tokenizationTIme.add(res.headers["X-Tokenization-Time"]);
queueTime.add(res.headers["X-Queue-Time"]);
inferenceTime.add(res.headers["X-Inference-Time"]);
} else {
console.log(res.error);
}
}
2 changes: 1 addition & 1 deletion router/Cargo.toml
@@ -35,7 +35,7 @@ reqwest = { version = "0.11.14", features = [] }
serde = "1.0.152"
serde_json = "1.0.93"
thiserror = "1.0.38"
tokenizers = { version = "0.14.1", default-features=false, features=["onig"] }
tokenizers = { version = "0.14.1", default-features=false, features=["onig", "esaxx_fast"] }
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tower-http = { version = "0.4.0", features = ["cors"] }
tracing = "0.1.37"
11 changes: 9 additions & 2 deletions router/src/lib.rs
@@ -59,7 +59,7 @@ pub(crate) struct OpenAICompatRequest {
pub(crate) struct OpenAICompatEmbedding {
#[schema(example = "embedding")]
object: &'static str,
#[schema(example = json ! (["0.0", "1.0", "2.0"]))]
#[schema(example = json!(["0.0", "1.0", "2.0"]))]
embedding: Vec<f32>,
#[schema(example = "0")]
index: usize,
@@ -89,10 +89,17 @@ pub(crate) struct EmbedRequest {
#[serde(default)]
#[schema(default = "false", example = "false")]
pub truncate: bool,
#[serde(default = "default_normalize")]
#[schema(default = "true", example = "true")]
pub normalize: bool,
}

fn default_normalize() -> bool {
true
}

#[derive(Serialize, ToSchema)]
#[schema(example = json ! ([["0.0", "1.0", "2.0"]]))]
#[schema(example = json!([["0.0", "1.0", "2.0"]]))]
pub(crate) struct EmbedResponse(Vec<Vec<f32>>);

#[derive(Serialize, ToSchema)]
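`EmbedRequest` gains a `normalize` field that defaults to `true` through `#[serde(default = "default_normalize")]`, so older clients that omit it keep the previous always-normalized behavior. A small sketch of that serde pattern, using an illustrative struct rather than the router's real one (assumes the `serde` and `serde_json` crates):

```rust
use serde::Deserialize;

fn default_normalize() -> bool {
    true
}

/// Illustrative stand-in for the router's `EmbedRequest`; only the fields
/// relevant to the default are shown.
#[derive(Deserialize, Debug)]
struct Request {
    inputs: String,
    #[serde(default)]
    truncate: bool,
    #[serde(default = "default_normalize")]
    normalize: bool,
}

fn main() {
    // A request that omits both optional fields.
    let r: Request = serde_json::from_str(r#"{"inputs": "hello"}"#).unwrap();
    assert!(!r.truncate);          // plain #[serde(default)] -> false
    assert!(r.normalize);          // falls back to default_normalize() -> true
    println!("{r:?}");
}
```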
10 changes: 6 additions & 4 deletions router/src/server.rs
@@ -93,7 +93,7 @@ async fn embed(

let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let response = infer
.embed(input, req.truncate, permit)
.embed(input, req.truncate, req.normalize, permit)
.await
.map_err(ErrorResponse::from)?;

@@ -135,7 +135,9 @@ async fn embed(
let local_infer = infer.clone();
futures.push(async move {
let permit = local_infer.acquire_permit().await;
local_infer.embed(input, req.truncate, permit).await
local_infer
.embed(input, req.truncate, req.normalize, permit)
.await
})
}
let results = join_all(futures)
@@ -269,7 +271,7 @@ async fn openai_embed(

let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let response = infer
.embed(input, false, permit)
.embed(input, false, true, permit)
.await
.map_err(ErrorResponse::from)?;

@@ -315,7 +317,7 @@ async fn openai_embed(
let local_infer = infer.clone();
futures.push(async move {
let permit = local_infer.acquire_permit().await;
local_infer.embed(input, false, permit).await
local_infer.embed(input, false, true, permit).await
})
}
let results = join_all(futures)
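server.rs threads `req.normalize` through the embed handlers and hardcodes `normalize = true` on the OpenAI-compatible route. A hedged client-side sketch of opting out of normalization (the route path, base URL, port, and the `inputs` field name are assumptions based on the public API and are not shown in this diff; requires `reqwest` with the `json` feature, `serde_json`, and `tokio`):

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Assumed local deployment; adjust the URL for your setup.
    let body = json!({
        "inputs": "What is Deep Learning?",
        "truncate": false,
        // New in this commit: ask the server to skip L2 normalization.
        "normalize": false
    });
    // The response body is a list of embeddings (Vec<Vec<f32>>).
    let embeddings: Vec<Vec<f32>> = reqwest::Client::new()
        .post("http://127.0.0.1:8080/embed")
        .json(&body)
        .send()
        .await?
        .json()
        .await?;
    println!("received {} embedding(s)", embeddings.len());
    Ok(())
}
```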
