From ae86f71d9019af9d13240bdf82fad79c3c965c6f Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 23 Nov 2023 15:15:32 +0100 Subject: [PATCH 01/12] add watcher refactor to http mod --- backends/src/lib.rs | 44 +++-- core/src/infer.rs | 7 +- router/src/http.rs | 2 + router/src/{ => http}/server.rs | 11 +- router/src/http/types.rs | 338 +++++++++++++++++++++++++++++++ router/src/lib.rs | 341 +------------------------------- 6 files changed, 382 insertions(+), 361 deletions(-) create mode 100644 router/src/http.rs rename router/src/{ => http}/server.rs (99%) create mode 100644 router/src/http/types.rs diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 77da3440..a0321020 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -1,11 +1,9 @@ mod dtype; use std::path::PathBuf; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; use std::time::{Duration, Instant}; use text_embeddings_backend_core::Backend as CoreBackend; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, watch}; use tracing::{instrument, Span}; pub use crate::dtype::DType; @@ -22,7 +20,7 @@ pub struct Backend { /// Channel to communicate with the background thread backend_sender: mpsc::UnboundedSender, /// Health status - health: Arc, + health_receiver: watch::Receiver, pub max_batch_size: Option, pub model_type: ModelType, } @@ -46,11 +44,15 @@ impl Backend { )?; let max_batch_size = backend.max_batch_size(); - tokio::task::spawn_blocking(move || backend_blocking_task(backend, backend_receiver)); + let (health_sender, health_receiver) = watch::channel(false); + + tokio::task::spawn_blocking(move || { + backend_blocking_task(backend, backend_receiver, health_sender) + }); Ok(Self { backend_sender, - health: Arc::new(AtomicBool::new(false)), + health_receiver, max_batch_size, model_type, }) @@ -58,7 +60,7 @@ impl Backend { #[instrument(skip(self))] pub async fn health(&self) -> Result<(), BackendError> { - let result = if self.health.load(Ordering::SeqCst) { + let result = if *self.health_receiver.borrow() { // The backend is healthy. Only do a basic health check by calling the // the underlying health method. @@ -86,11 +88,14 @@ impl Backend { } }; - // Update health - self.health.store(result.is_ok(), Ordering::SeqCst); result } + #[instrument(skip(self))] + pub fn health_watcher(&self) -> watch::Receiver { + self.health_receiver.clone() + } + #[instrument(skip_all)] pub async fn embed(&self, batch: Batch) -> Result<(Vec, Duration), BackendError> { let (sender, receiver) = oneshot::channel(); @@ -102,8 +107,6 @@ impl Backend { "Backend blocking task dropped the sender without send a response. This is a bug.", ); - // Update health - self.health.store(result.is_ok(), Ordering::SeqCst); result } @@ -118,8 +121,6 @@ impl Backend { "Backend blocking task dropped the sender without send a response. 
This is a bug.", ); - // Update health - self.health.store(result.is_ok(), Ordering::SeqCst); result } } @@ -165,23 +166,32 @@ fn init_backend( fn backend_blocking_task( backend: Box, mut command_receiver: mpsc::UnboundedReceiver, + health_sender: watch::Sender, ) { while let Some(cmd) = command_receiver.blocking_recv() { let start = Instant::now(); + let mut healthy = false; match cmd { BackendCommand::Health(span, sender) => { let _span = span.entered(); - let _ = sender.send(backend.health()); + let _ = sender.send(backend.health().map(|_| healthy = true)); } BackendCommand::Embed(batch, span, sender) => { let _span = span.entered(); - let _ = sender.send(backend.embed(batch).map(|e| (e, start.elapsed()))); + let _ = sender.send(backend.embed(batch).map(|e| { + healthy = true; + (e, start.elapsed()) + })); } BackendCommand::Predict(batch, span, sender) => { let _span = span.entered(); - let _ = sender.send(backend.predict(batch).map(|e| (e, start.elapsed()))); + let _ = sender.send(backend.predict(batch).map(|e| { + healthy = true; + (e, start.elapsed()) + })); } - } + }; + let _ = health_sender.send(healthy); } } diff --git a/core/src/infer.rs b/core/src/infer.rs index 1c347234..d58c03d1 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -4,7 +4,7 @@ use crate::TextEmbeddingsError; use std::sync::Arc; use std::time::{Duration, Instant}; use text_embeddings_backend::{Backend, BackendError, ModelType}; -use tokio::sync::{mpsc, oneshot, Notify, OwnedSemaphorePermit, Semaphore}; +use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore}; use tracing::{instrument, Span}; /// Inference struct @@ -285,6 +285,11 @@ impl Infer { pub async fn health(&self) -> bool { self.backend.health().await.is_ok() } + + #[instrument(skip(self))] + pub fn health_watcher(&self) -> watch::Receiver { + self.backend.health_watcher() + } } #[instrument(skip_all)] diff --git a/router/src/http.rs b/router/src/http.rs new file mode 100644 index 00000000..3214bbf0 --- /dev/null +++ b/router/src/http.rs @@ -0,0 +1,2 @@ +pub mod server; +mod types; diff --git a/router/src/server.rs b/router/src/http/server.rs similarity index 99% rename from router/src/server.rs rename to router/src/http/server.rs index 49c0c6aa..a7643ced 100644 --- a/router/src/server.rs +++ b/router/src/http/server.rs @@ -1,10 +1,11 @@ /// HTTP Server logic -use crate::{ - ClassifierModel, EmbedRequest, EmbedResponse, EmbeddingModel, ErrorResponse, ErrorType, Info, - Input, ModelType, OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, - OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, - Prediction, Rank, RerankRequest, RerankResponse, Sequence, +use crate::http::types::{ + EmbedRequest, EmbedResponse, ErrorResponse, ErrorType, Input, OpenAICompatEmbedding, + OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, + PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, + Sequence, }; +use crate::{ClassifierModel, EmbeddingModel, Info, ModelType}; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; use axum::routing::{get, post}; diff --git a/router/src/http/types.rs b/router/src/http/types.rs new file mode 100644 index 00000000..9d004907 --- /dev/null +++ b/router/src/http/types.rs @@ -0,0 +1,338 @@ +use serde::de::{SeqAccess, Visitor}; +use serde::{de, Deserialize, Deserializer, Serialize}; +use serde_json::json; +use std::fmt::Formatter; +use 
text_embeddings_core::tokenization::EncodingInput; +use utoipa::openapi::{RefOr, Schema}; +use utoipa::ToSchema; + +#[derive(Debug)] +pub(crate) enum Sequence { + Single(String), + Pair(String, String), +} + +impl Sequence { + pub(crate) fn count_chars(&self) -> usize { + match self { + Sequence::Single(s) => s.chars().count(), + Sequence::Pair(s1, s2) => s1.chars().count() + s2.chars().count(), + } + } +} + +impl From for EncodingInput { + fn from(value: Sequence) -> Self { + match value { + Sequence::Single(s) => Self::Single(s), + Sequence::Pair(s1, s2) => Self::Dual(s1, s2), + } + } +} + +#[derive(Debug)] +pub(crate) enum PredictInput { + Single(Sequence), + Batch(Vec), +} + +impl<'de> Deserialize<'de> for PredictInput { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(untagged)] + enum Internal { + Single(String), + Multiple(Vec), + } + + struct PredictInputVisitor; + + impl<'de> Visitor<'de> for PredictInputVisitor { + type Value = PredictInput; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str( + "a string, \ + a pair of strings [string, string] \ + or a batch of mixed strings and pairs [[string], [string, string], ...]", + ) + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + Ok(PredictInput::Single(Sequence::Single(v.to_string()))) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let sequence_from_vec = |mut value: Vec| { + // Validate that value is correct + match value.len() { + 1 => Ok(Sequence::Single(value.pop().unwrap())), + 2 => { + // Second element is last + let second = value.pop().unwrap(); + let first = value.pop().unwrap(); + Ok(Sequence::Pair(first, second)) + } + // Sequence can only be a single string or a pair of strings + _ => Err(de::Error::invalid_length(value.len(), &self)), + } + }; + + // Get first element + // This will determine if input is a batch or not + let s = match seq + .next_element::()? + .ok_or_else(|| de::Error::invalid_length(0, &self))? + { + // Input is not a batch + // Return early + Internal::Single(value) => { + // Option get second element + let second = seq.next_element()?; + + if seq.next_element::()?.is_some() { + // Error as we do not accept > 2 elements + return Err(de::Error::invalid_length(3, &self)); + } + + if let Some(second) = second { + // Second element exists + // This is a pair + return Ok(PredictInput::Single(Sequence::Pair(value, second))); + } else { + // Second element does not exist + return Ok(PredictInput::Single(Sequence::Single(value))); + } + } + // Input is a batch + Internal::Multiple(value) => sequence_from_vec(value), + }?; + + let mut batch = Vec::with_capacity(32); + // Push first sequence + batch.push(s); + + // Iterate on all sequences + while let Some(value) = seq.next_element::>()? 
{ + // Validate sequence + let s = sequence_from_vec(value)?; + // Push to batch + batch.push(s); + } + Ok(PredictInput::Batch(batch)) + } + } + + deserializer.deserialize_any(PredictInputVisitor) + } +} + +impl<'__s> ToSchema<'__s> for PredictInput { + fn schema() -> (&'__s str, RefOr) { + ( + "PredictInput", + utoipa::openapi::OneOfBuilder::new() + .item( + utoipa::openapi::ObjectBuilder::new() + .schema_type(utoipa::openapi::SchemaType::String) + .description(Some("A single string")), + ) + .item( + utoipa::openapi::ArrayBuilder::new() + .items( + utoipa::openapi::ObjectBuilder::new() + .schema_type(utoipa::openapi::SchemaType::String), + ) + .description(Some("A pair of strings")) + .min_items(Some(2)) + .max_items(Some(2)), + ) + .item( + utoipa::openapi::ArrayBuilder::new().items( + utoipa::openapi::OneOfBuilder::new() + .item( + utoipa::openapi::ArrayBuilder::new() + .items( + utoipa::openapi::ObjectBuilder::new() + .schema_type(utoipa::openapi::SchemaType::String), + ) + .description(Some("A single string")) + .min_items(Some(1)) + .max_items(Some(1)), + ) + .item( + utoipa::openapi::ArrayBuilder::new() + .items( + utoipa::openapi::ObjectBuilder::new() + .schema_type(utoipa::openapi::SchemaType::String), + ) + .description(Some("A pair of strings")) + .min_items(Some(2)) + .max_items(Some(2)), + ) + ).description(Some("A batch")), + ) + .description(Some( + "Model input. \ + Can be either a single string, a pair of strings or a batch of mixed single and pairs \ + of strings.", + )) + .example(Some(json!("What is Deep Learning?"))) + .into(), + ) + } +} + +#[derive(Deserialize, ToSchema)] +pub(crate) struct PredictRequest { + pub inputs: PredictInput, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub truncate: bool, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub raw_scores: bool, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct Prediction { + #[schema(example = "0.5")] + pub score: f32, + #[schema(example = "admiration")] + pub label: String, +} + +#[derive(Serialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum PredictResponse { + Single(Vec), + Batch(Vec>), +} + +#[derive(Deserialize, ToSchema)] +pub(crate) struct RerankRequest { + #[schema(example = "What is Deep Learning?")] + pub query: String, + #[schema(example = json!(["Deep Learning is ..."]))] + pub texts: Vec, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub truncate: bool, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub raw_scores: bool, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub return_text: bool, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct Rank { + #[schema(example = "0")] + pub index: usize, + #[schema(nullable = true, example = "Deep Learning is ...", default = "null")] + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + #[schema(example = "1.0")] + pub score: f32, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct RerankResponse(pub Vec); + +#[derive(Deserialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum Input { + Single(String), + Batch(Vec), +} + +#[derive(Deserialize, ToSchema)] +pub(crate) struct OpenAICompatRequest { + pub input: Input, + #[allow(dead_code)] + #[schema(nullable = true, example = "null")] + pub model: Option, + #[allow(dead_code)] + #[schema(nullable = true, example = "null")] + pub user: Option, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct OpenAICompatEmbedding { + #[schema(example = 
"embedding")] + pub object: &'static str, + #[schema(example = json!(["0.0", "1.0", "2.0"]))] + pub embedding: Vec, + #[schema(example = "0")] + pub index: usize, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct OpenAICompatUsage { + #[schema(example = "512")] + pub prompt_tokens: usize, + #[schema(example = "512")] + pub total_tokens: usize, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct OpenAICompatResponse { + #[schema(example = "list")] + pub object: &'static str, + pub data: Vec, + #[schema(example = "thenlper/gte-base")] + pub model: String, + pub usage: OpenAICompatUsage, +} + +#[derive(Deserialize, ToSchema)] +pub(crate) struct EmbedRequest { + pub inputs: Input, + #[serde(default)] + #[schema(default = "false", example = "false")] + pub truncate: bool, + #[serde(default = "default_normalize")] + #[schema(default = "true", example = "true")] + pub normalize: bool, +} + +fn default_normalize() -> bool { + true +} + +#[derive(Serialize, ToSchema)] +#[schema(example = json!([["0.0", "1.0", "2.0"]]))] +pub(crate) struct EmbedResponse(pub Vec>); + +#[derive(Serialize, ToSchema)] +pub(crate) enum ErrorType { + Unhealthy, + Backend, + Overloaded, + Validation, + Tokenizer, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct ErrorResponse { + pub error: String, + pub error_type: ErrorType, +} + +#[derive(Serialize, ToSchema)] +pub(crate) struct OpenAICompatErrorResponse { + pub message: String, + pub code: u16, + #[serde(rename(serialize = "type"))] + pub error_type: ErrorType, +} diff --git a/router/src/lib.rs b/router/src/lib.rs index 9822310d..948b7fe4 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,13 +1,9 @@ /// Text Embedding Inference Webserver -pub mod server; +mod http; -use serde::de::{SeqAccess, Visitor}; -use serde::{de, Deserialize, Deserializer, Serialize}; -use serde_json::json; +pub use http::server; +use serde::Serialize; use std::collections::HashMap; -use std::fmt::Formatter; -use text_embeddings_core::tokenization::EncodingInput; -use utoipa::openapi::{RefOr, Schema}; use utoipa::ToSchema; #[derive(Clone, Debug, Serialize, ToSchema)] @@ -62,334 +58,3 @@ pub struct Info { #[schema(nullable = true, example = "null")] pub docker_label: Option<&'static str>, } - -#[derive(Debug)] -pub(crate) enum Sequence { - Single(String), - Pair(String, String), -} - -impl Sequence { - pub(crate) fn count_chars(&self) -> usize { - match self { - Sequence::Single(s) => s.chars().count(), - Sequence::Pair(s1, s2) => s1.chars().count() + s2.chars().count(), - } - } -} - -impl From for EncodingInput { - fn from(value: Sequence) -> Self { - match value { - Sequence::Single(s) => Self::Single(s), - Sequence::Pair(s1, s2) => Self::Dual(s1, s2), - } - } -} - -#[derive(Debug)] -pub(crate) enum PredictInput { - Single(Sequence), - Batch(Vec), -} - -impl<'de> Deserialize<'de> for PredictInput { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - #[derive(Deserialize)] - #[serde(untagged)] - enum Internal { - Single(String), - Multiple(Vec), - } - - struct PredictInputVisitor; - - impl<'de> Visitor<'de> for PredictInputVisitor { - type Value = PredictInput; - - fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { - formatter.write_str( - "a string, \ - a pair of strings [string, string] \ - or a batch of mixed strings and pairs [[string], [string, string], ...]", - ) - } - - fn visit_str(self, v: &str) -> Result - where - E: de::Error, - { - Ok(PredictInput::Single(Sequence::Single(v.to_string()))) - } - - fn 
visit_seq(self, mut seq: A) -> Result - where - A: SeqAccess<'de>, - { - let sequence_from_vec = |mut value: Vec| { - // Validate that value is correct - match value.len() { - 1 => Ok(Sequence::Single(value.pop().unwrap())), - 2 => { - // Second element is last - let second = value.pop().unwrap(); - let first = value.pop().unwrap(); - Ok(Sequence::Pair(first, second)) - } - // Sequence can only be a single string or a pair of strings - _ => Err(de::Error::invalid_length(value.len(), &self)), - } - }; - - // Get first element - // This will determine if input is a batch or not - let s = match seq - .next_element::()? - .ok_or_else(|| de::Error::invalid_length(0, &self))? - { - // Input is not a batch - // Return early - Internal::Single(value) => { - // Option get second element - let second = seq.next_element()?; - - if seq.next_element::()?.is_some() { - // Error as we do not accept > 2 elements - return Err(de::Error::invalid_length(3, &self)); - } - - if let Some(second) = second { - // Second element exists - // This is a pair - return Ok(PredictInput::Single(Sequence::Pair(value, second))); - } else { - // Second element does not exist - return Ok(PredictInput::Single(Sequence::Single(value))); - } - } - // Input is a batch - Internal::Multiple(value) => sequence_from_vec(value), - }?; - - let mut batch = Vec::with_capacity(32); - // Push first sequence - batch.push(s); - - // Iterate on all sequences - while let Some(value) = seq.next_element::>()? { - // Validate sequence - let s = sequence_from_vec(value)?; - // Push to batch - batch.push(s); - } - Ok(PredictInput::Batch(batch)) - } - } - - deserializer.deserialize_any(PredictInputVisitor) - } -} - -impl<'__s> ToSchema<'__s> for PredictInput { - fn schema() -> (&'__s str, RefOr) { - ( - "PredictInput", - utoipa::openapi::OneOfBuilder::new() - .item( - utoipa::openapi::ObjectBuilder::new() - .schema_type(utoipa::openapi::SchemaType::String) - .description(Some("A single string")), - ) - .item( - utoipa::openapi::ArrayBuilder::new() - .items( - utoipa::openapi::ObjectBuilder::new() - .schema_type(utoipa::openapi::SchemaType::String), - ) - .description(Some("A pair of strings")) - .min_items(Some(2)) - .max_items(Some(2)), - ) - .item( - utoipa::openapi::ArrayBuilder::new().items( - utoipa::openapi::OneOfBuilder::new() - .item( - utoipa::openapi::ArrayBuilder::new() - .items( - utoipa::openapi::ObjectBuilder::new() - .schema_type(utoipa::openapi::SchemaType::String), - ) - .description(Some("A single string")) - .min_items(Some(1)) - .max_items(Some(1)), - ) - .item( - utoipa::openapi::ArrayBuilder::new() - .items( - utoipa::openapi::ObjectBuilder::new() - .schema_type(utoipa::openapi::SchemaType::String), - ) - .description(Some("A pair of strings")) - .min_items(Some(2)) - .max_items(Some(2)), - ) - ).description(Some("A batch")), - ) - .description(Some( - "Model input. 
\ - Can be either a single string, a pair of strings or a batch of mixed single and pairs \ - of strings.", - )) - .example(Some(json!("What is Deep Learning?"))) - .into(), - ) - } -} - -#[derive(Deserialize, ToSchema)] -pub(crate) struct PredictRequest { - pub inputs: PredictInput, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub truncate: bool, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub raw_scores: bool, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct Prediction { - #[schema(example = "0.5")] - score: f32, - #[schema(example = "admiration")] - label: String, -} - -#[derive(Serialize, ToSchema)] -#[serde(untagged)] -pub(crate) enum PredictResponse { - Single(Vec), - Batch(Vec>), -} - -#[derive(Deserialize, ToSchema)] -pub(crate) struct RerankRequest { - #[schema(example = "What is Deep Learning?")] - pub query: String, - #[schema(example = json!(["Deep Learning is ..."]))] - pub texts: Vec, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub truncate: bool, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub raw_scores: bool, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub return_text: bool, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct Rank { - #[schema(example = "0")] - pub index: usize, - #[schema(nullable = true, example = "Deep Learning is ...", default = "null")] - #[serde(skip_serializing_if = "Option::is_none")] - pub text: Option, - #[schema(example = "1.0")] - pub score: f32, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct RerankResponse(Vec); - -#[derive(Deserialize, ToSchema)] -#[serde(untagged)] -pub(crate) enum Input { - Single(String), - Batch(Vec), -} - -#[derive(Deserialize, ToSchema)] -pub(crate) struct OpenAICompatRequest { - pub input: Input, - #[allow(dead_code)] - #[schema(nullable = true, example = "null")] - model: Option, - #[allow(dead_code)] - #[schema(nullable = true, example = "null")] - user: Option, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct OpenAICompatEmbedding { - #[schema(example = "embedding")] - object: &'static str, - #[schema(example = json!([0.0, 1.0, 2.0]))] - embedding: Vec, - #[schema(example = "0")] - index: usize, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct OpenAICompatUsage { - #[schema(example = "512")] - prompt_tokens: usize, - #[schema(example = "512")] - total_tokens: usize, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct OpenAICompatResponse { - #[schema(example = "list")] - object: &'static str, - data: Vec, - #[schema(example = "thenlper/gte-base")] - model: String, - usage: OpenAICompatUsage, -} - -#[derive(Deserialize, ToSchema)] -pub(crate) struct EmbedRequest { - pub inputs: Input, - #[serde(default)] - #[schema(default = "false", example = "false")] - pub truncate: bool, - #[serde(default = "default_normalize")] - #[schema(default = "true", example = "true")] - pub normalize: bool, -} - -fn default_normalize() -> bool { - true -} - -#[derive(Serialize, ToSchema)] -#[schema(example = json!([[0.0, 1.0, 2.0]]))] -pub(crate) struct EmbedResponse(Vec>); - -#[derive(Serialize, ToSchema)] -pub(crate) enum ErrorType { - Unhealthy, - Backend, - Overloaded, - Validation, - Tokenizer, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct ErrorResponse { - pub error: String, - pub error_type: ErrorType, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct OpenAICompatErrorResponse { - pub message: String, - pub code: u16, - 
#[serde(rename(serialize = "type"))] - pub error_type: ErrorType, -} From b856ec71d01fe663516bab848f3ccdc1a51ea77b Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 23 Nov 2023 15:34:23 +0100 Subject: [PATCH 02/12] refacto cors --- Cargo.lock | 35 ----------------------------------- router/Cargo.toml | 17 ++++++++--------- router/src/http/server.rs | 15 ++++++++++++++- router/src/lib.rs | 36 +++++++++++++++++++++++++++++------- router/src/main.rs | 20 ++------------------ 5 files changed, 53 insertions(+), 70 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 023dc833..77e4eb1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -116,28 +116,6 @@ dependencies = [ "backtrace", ] -[[package]] -name = "async-stream" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.39", -] - [[package]] name = "async-trait" version = "0.1.74" @@ -3126,7 +3104,6 @@ name = "text-embeddings-router" version = "0.5.0" dependencies = [ "anyhow", - "async-stream", "axum", "axum-tracing-opentelemetry", "clap", @@ -3148,7 +3125,6 @@ dependencies = [ "tokio", "tower-http", "tracing", - "tracing-chrome", "tracing-opentelemetry 0.21.0", "tracing-subscriber", "utoipa", @@ -3462,17 +3438,6 @@ dependencies = [ "syn 2.0.39", ] -[[package]] -name = "tracing-chrome" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "496b3cd5447f7ff527bbbf19b071ad542a000adf297d4127078b4dfdb931f41a" -dependencies = [ - "serde_json", - "tracing-core", - "tracing-subscriber", -] - [[package]] name = "tracing-core" version = "0.1.32" diff --git a/router/Cargo.toml b/router/Cargo.toml index a6b7494a..596b038d 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -16,13 +16,12 @@ path = "src/main.rs" [dependencies] anyhow = "1.0.71" -async-stream = "0.3.3" -axum = { version = "0.6.4", features = ["json"] } -axum-tracing-opentelemetry = "0.14.1" +axum = { version = "0.6.4", features = ["json"], optional = true } +axum-tracing-opentelemetry = { version = "0.14.1", optional = true } text-embeddings-backend = { path = "../backends", features = ["clap"] } text-embeddings-core = { path = "../core" } clap = { version = "4.1.4", features = ["derive", "env"] } -futures = "^0.3" +futures = { version = "^0.3", optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } hf-hub = { version = "0.3.0", features = ["tokio"] } num_cpus = "1.16.0" @@ -36,20 +35,20 @@ serde_json = "1.0.93" thiserror = "1.0.38" tokenizers = { version = "0.15.0", default-features=false, features=["onig", "esaxx_fast"] } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } -tower-http = { version = "0.4.0", features = ["cors"] } +tower-http = { version = "0.4.0", features = ["cors"], optional = true } tracing = "0.1.37" -tracing-chrome = "0.7.1" tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } -utoipa = { version = "4.0.0", features = ["axum_extras"] } -utoipa-swagger-ui = { 
version = "4.0.0", features = ["axum"] } +utoipa = { version = "4.0.0", features = ["axum_extras"], optional = true } +utoipa-swagger-ui = { version = "4.0.0", features = ["axum"], optional = true } veil = "0.1.6" [build-dependencies] vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } [features] -default = ["candle"] +default = ["candle", "http"] +http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:futures", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"] mkl = ["text-embeddings-backend/mkl"] mkl-dynamic = ["text-embeddings-backend/mkl-dynamic"] accelerate = ["text-embeddings-backend/accelerate"] diff --git a/router/src/http/server.rs b/router/src/http/server.rs index a7643ced..999ad620 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -1,3 +1,4 @@ +use std::env; /// HTTP Server logic use crate::http::types::{ EmbedRequest, EmbedResponse, ErrorResponse, ErrorType, Input, OpenAICompatEmbedding, @@ -23,6 +24,7 @@ use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::instrument; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; +use axum::http::HeaderValue; ///Text Embeddings Inference endpoint info #[utoipa::path( @@ -876,7 +878,7 @@ pub async fn run( infer: Infer, info: Info, addr: SocketAddr, - allow_origin: Option, + // allow_origin: Option, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -927,6 +929,17 @@ pub async fn run( )] struct ApiDoc; + // CORS allowed origins + // map to go inside the option and then map to parse from String to HeaderValue + // Finally, convert to AllowOrigin + let allow_origin: Option = env::var("CORS_ALLOW_ORIGIN").ok().map(|cors_allow_origin| { + let cors_allow_origin = cors_allow_origin.split(","); + AllowOrigin::list( + cors_allow_origin + .map(|origin| origin.parse::().unwrap()), + ) + }); + // Duration buckets let duration_matcher = Matcher::Suffix(String::from("duration")); let n_duration_buckets = 35; diff --git a/router/src/lib.rs b/router/src/lib.rs index 948b7fe4..cdc2513f 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,18 +1,38 @@ /// Text Embedding Inference Webserver -mod http; -pub use http::server; use serde::Serialize; use std::collections::HashMap; -use utoipa::ToSchema; +use std::net::SocketAddr; +use text_embeddings_core::infer::Infer; + +#[cfg(feature = "http")] +mod http; + +pub async fn run( + infer: Infer, + info: Info, + addr: SocketAddr, +) -> Result<(), BoxError> { + if cfg!(feature = "http") { + #[cfg(feature = "http")] + { + return http::server::run(infer, info, addr).await; + } + } + panic!(); +} + +pub type BoxError = Box; -#[derive(Clone, Debug, Serialize, ToSchema)] +#[derive(Clone, Debug, Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] pub struct EmbeddingModel { #[schema(example = "cls")] pub pooling: String, } -#[derive(Clone, Debug, Serialize, ToSchema)] +#[derive(Clone, Debug, Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] pub struct ClassifierModel { #[schema(example = json!({"0": "LABEL"}))] pub id2label: HashMap, @@ -20,14 +40,16 @@ pub struct ClassifierModel { pub label2id: HashMap, } -#[derive(Clone, Debug, Serialize, ToSchema)] +#[derive(Clone, Debug, Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] #[serde(rename_all = "lowercase")] pub enum ModelType { Classifier(ClassifierModel), Embedding(EmbeddingModel), } -#[derive(Clone, Debug, Serialize, ToSchema)] +#[derive(Clone, Debug, Serialize)] +#[cfg_attr(feature = "http", 
derive(utoipa::ToSchema))] pub struct Info { /// Model info #[schema(example = "thenlper/gte-base")] diff --git a/router/src/main.rs b/router/src/main.rs index bbb4c260..b9c162ba 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,5 +1,4 @@ use anyhow::{anyhow, Context, Result}; -use axum::http::HeaderValue; use clap::Parser; use hf_hub::api::tokio::ApiBuilder; use hf_hub::{Repo, RepoType}; @@ -18,10 +17,9 @@ use text_embeddings_core::download::{download_artifacts, download_pool_config}; use text_embeddings_core::infer::Infer; use text_embeddings_core::queue::Queue; use text_embeddings_core::tokenization::Tokenization; -use text_embeddings_router::{server, ClassifierModel, EmbeddingModel, Info, ModelType}; +use text_embeddings_router::{ClassifierModel, EmbeddingModel, Info, ModelType}; use tokenizers::decoders::metaspace::PrependScheme; use tokenizers::{PreTokenizerWrapper, Tokenizer}; -use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{EnvFilter, Layer}; @@ -121,9 +119,6 @@ struct Args { #[clap(long, env)] otlp_endpoint: Option, - - #[clap(long, env)] - cors_allow_origin: Option>, } #[derive(Debug, Deserialize)] @@ -368,21 +363,10 @@ async fn main() -> Result<()> { } }; - // CORS allowed origins - // map to go inside the option and then map to parse from String to HeaderValue - // Finally, convert to AllowOrigin - let cors_allow_origin: Option = args.cors_allow_origin.map(|cors_allow_origin| { - AllowOrigin::list( - cors_allow_origin - .iter() - .map(|origin| origin.parse::().unwrap()), - ) - }); - tracing::info!("Ready"); // Run axum server - server::run(infer, info, addr, cors_allow_origin) + text_embeddings_router::run(infer, info, addr) .await .unwrap(); Ok(()) From 85b09904aac3af35e728f70b2568a6497b6836cd Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 23 Nov 2023 16:00:23 +0100 Subject: [PATCH 03/12] move prometheus builder --- router/src/grpc.rs | 1 + router/src/http/server.rs | 71 +++++++++++---------------------------- router/src/lib.rs | 18 +++++----- router/src/prometheus.rs | 36 ++++++++++++++++++++ 4 files changed, 65 insertions(+), 61 deletions(-) create mode 100644 router/src/grpc.rs create mode 100644 router/src/prometheus.rs diff --git a/router/src/grpc.rs b/router/src/grpc.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/router/src/grpc.rs @@ -0,0 +1 @@ + diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 999ad620..60e28552 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -1,4 +1,3 @@ -use std::env; /// HTTP Server logic use crate::http::types::{ EmbedRequest, EmbedResponse, ErrorResponse, ErrorType, Input, OpenAICompatEmbedding, @@ -6,14 +5,18 @@ use crate::http::types::{ PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, }; +use crate::prometheus::prometheus_builer; use crate::{ClassifierModel, EmbeddingModel, Info, ModelType}; +use anyhow::Context; use axum::extract::Extension; +use axum::http::HeaderValue; use axum::http::{HeaderMap, Method, StatusCode}; use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::middleware::OtelAxumLayer; use futures::future::join_all; -use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; +use metrics_exporter_prometheus::PrometheusHandle; +use std::env; use 
std::net::SocketAddr; use std::time::{Duration, Instant}; use text_embeddings_backend::BackendError; @@ -24,7 +27,6 @@ use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::instrument; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; -use axum::http::HeaderValue; ///Text Embeddings Inference endpoint info #[utoipa::path( @@ -878,8 +880,7 @@ pub async fn run( infer: Infer, info: Info, addr: SocketAddr, - // allow_origin: Option, -) -> Result<(), axum::BoxError> { +) -> Result<(), anyhow::Error> { // OpenAPI documentation #[derive(OpenApi)] #[openapi( @@ -932,54 +933,20 @@ pub async fn run( // CORS allowed origins // map to go inside the option and then map to parse from String to HeaderValue // Finally, convert to AllowOrigin - let allow_origin: Option = env::var("CORS_ALLOW_ORIGIN").ok().map(|cors_allow_origin| { - let cors_allow_origin = cors_allow_origin.split(","); - AllowOrigin::list( - cors_allow_origin - .map(|origin| origin.parse::().unwrap()), - ) - }); - - // Duration buckets - let duration_matcher = Matcher::Suffix(String::from("duration")); - let n_duration_buckets = 35; - let mut duration_buckets = Vec::with_capacity(n_duration_buckets); - // Minimum duration in seconds - let mut value = 0.00001; - for _ in 0..n_duration_buckets { - // geometric sequence - value *= 1.5; - duration_buckets.push(value); - } - - // Input Length buckets - let input_length_matcher = Matcher::Full(String::from("te_request_input_length")); - let input_length_buckets: Vec = (0..100) - .map(|x| (info.max_input_length as f64 / 100.0) * (x + 1) as f64) - .collect(); - - // Batch size buckets - let batch_size_matcher = Matcher::Full(String::from("te_batch_next_size")); - let batch_size_buckets: Vec = (0..2048).map(|x| (x + 1) as f64).collect(); - - // Batch tokens buckets - let batch_tokens_matcher = Matcher::Full(String::from("te_batch_next_tokens")); - let batch_tokens_buckets: Vec = (0..100_000).map(|x| (x + 1) as f64).collect(); - - // Prometheus handler - let builder = PrometheusBuilder::new() - .set_buckets_for_metric(duration_matcher, &duration_buckets) - .unwrap() - .set_buckets_for_metric(input_length_matcher, &input_length_buckets) - .unwrap() - .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets) - .unwrap() - .set_buckets_for_metric(batch_tokens_matcher, &batch_tokens_buckets) - .unwrap(); - - let prom_handle = builder + let allow_origin: Option = + env::var("CORS_ALLOW_ORIGIN").ok().map(|cors_allow_origin| { + let cors_allow_origin = cors_allow_origin.split(","); + AllowOrigin::list( + cors_allow_origin.map(|origin| origin.parse::().unwrap()), + ) + }); + + let prometheus_builder = + prometheus_builer(info.max_input_length).context("failed to build prometheus exporter")?; + + let prom_handle = prometheus_builder .install_recorder() - .expect("failed to install metrics recorder"); + .context("failed to install metrics recorder")?; // CORS layer let allow_origin = allow_origin.unwrap_or(AllowOrigin::any()); diff --git a/router/src/lib.rs b/router/src/lib.rs index cdc2513f..c2db692b 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,28 +1,28 @@ +use anyhow::Result; /// Text Embedding Inference Webserver - use serde::Serialize; use std::collections::HashMap; use std::net::SocketAddr; use text_embeddings_core::infer::Infer; +mod prometheus; + #[cfg(feature = "http")] mod http; -pub async fn run( - infer: Infer, - info: Info, - addr: SocketAddr, -) -> Result<(), BoxError> { +#[cfg(feature = "grpc")] +mod grpc; + +pub async fn run(infer: Infer, info: Info, addr: 
SocketAddr) -> Result<()> { if cfg!(feature = "http") { #[cfg(feature = "http")] { return http::server::run(infer, info, addr).await; } } - panic!(); -} -pub type BoxError = Box; + anyhow::bail!("You must use one of `http` or `grpc`"); +} #[derive(Clone, Debug, Serialize)] #[cfg_attr(feature = "http", derive(utoipa::ToSchema))] diff --git a/router/src/prometheus.rs b/router/src/prometheus.rs new file mode 100644 index 00000000..3c16684e --- /dev/null +++ b/router/src/prometheus.rs @@ -0,0 +1,36 @@ +use metrics_exporter_prometheus::{BuildError, Matcher, PrometheusBuilder}; + +pub fn prometheus_builer(max_input_length: usize) -> Result { + // Duration buckets + let duration_matcher = Matcher::Suffix(String::from("duration")); + let n_duration_buckets = 35; + let mut duration_buckets = Vec::with_capacity(n_duration_buckets); + // Minimum duration in seconds + let mut value = 0.00001; + for _ in 0..n_duration_buckets { + // geometric sequence + value *= 1.5; + duration_buckets.push(value); + } + + // Input Length buckets + let input_length_matcher = Matcher::Full(String::from("te_request_input_length")); + let input_length_buckets: Vec = (0..100) + .map(|x| (max_input_length as f64 / 100.0) * (x + 1) as f64) + .collect(); + + // Batch size buckets + let batch_size_matcher = Matcher::Full(String::from("te_batch_next_size")); + let batch_size_buckets: Vec = (0..2048).map(|x| (x + 1) as f64).collect(); + + // Batch tokens buckets + let batch_tokens_matcher = Matcher::Full(String::from("te_batch_next_tokens")); + let batch_tokens_buckets: Vec = (0..100_000).map(|x| (x + 1) as f64).collect(); + + // Prometheus handler + PrometheusBuilder::new() + .set_buckets_for_metric(duration_matcher, &duration_buckets)? + .set_buckets_for_metric(input_length_matcher, &input_length_buckets)? + .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)? 
+ .set_buckets_for_metric(batch_tokens_matcher, &batch_tokens_buckets) +} From 113df898b200a3cfb79329ffd5c774f7e9e40c37 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 23 Nov 2023 18:31:57 +0100 Subject: [PATCH 04/12] grpc support --- Cargo.lock | 192 +++++++++- backends/src/lib.rs | 18 +- core/src/infer.rs | 2 +- load_tests/load.js | 4 +- load_tests/load_grpc.js | 66 ++++ proto/text_embeddings.proto | 80 ++++ router/Cargo.toml | 23 +- router/build.rs | 19 + router/src/grpc.rs | 1 - router/src/grpc/mod.rs | 4 + router/src/grpc/server.rs | 575 ++++++++++++++++++++++++++++ router/src/{http.rs => http/mod.rs} | 0 router/src/http/server.rs | 103 ++--- router/src/http/types.rs | 16 +- router/src/lib.rs | 82 +++- router/src/main.rs | 33 +- router/src/shutdown.rs | 29 ++ 17 files changed, 1091 insertions(+), 156 deletions(-) create mode 100644 load_tests/load_grpc.js create mode 100644 proto/text_embeddings.proto delete mode 100644 router/src/grpc.rs create mode 100644 router/src/grpc/mod.rs create mode 100644 router/src/grpc/server.rs rename router/src/{http.rs => http/mod.rs} (100%) create mode 100644 router/src/shutdown.rs diff --git a/Cargo.lock b/Cargo.lock index 77e4eb1e..225b8faa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -116,6 +116,28 @@ dependencies = [ "backtrace", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.39", +] + [[package]] name = "async-trait" version = "0.1.74" @@ -205,12 +227,12 @@ name = "backend-grpc-client" version = "0.5.0" dependencies = [ "grpc-metadata", - "prost", - "prost-build", + "prost 0.11.9", + "prost-build 0.11.9", "thiserror", "tokio", - "tonic", - "tonic-build", + "tonic 0.9.2", + "tonic-build 0.9.2", "tower", "tracing", ] @@ -1164,7 +1186,7 @@ name = "grpc-metadata" version = "0.1.0" dependencies = [ "opentelemetry 0.19.0", - "tonic", + "tonic 0.9.2", "tracing", "tracing-opentelemetry 0.19.0", ] @@ -2019,10 +2041,10 @@ dependencies = [ "opentelemetry-semantic-conventions", "opentelemetry_api 0.20.0", "opentelemetry_sdk 0.20.0", - "prost", + "prost 0.11.9", "thiserror", "tokio", - "tonic", + "tonic 0.9.2", ] [[package]] @@ -2033,8 +2055,8 @@ checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb" dependencies = [ "opentelemetry_api 0.20.0", "opentelemetry_sdk 0.20.0", - "prost", - "tonic", + "prost 0.11.9", + "tonic 0.9.2", ] [[package]] @@ -2253,6 +2275,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "prettyplease" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" +dependencies = [ + "proc-macro2", + "syn 2.0.39", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -2293,7 +2325,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.11.9", +] + +[[package]] +name = "prost" +version = 
"0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c289cda302b98a28d40c8b3b90498d6e526dd24ac2ecea73e4e491685b94a" +dependencies = [ + "bytes", + "prost-derive 0.12.3", ] [[package]] @@ -2309,15 +2351,37 @@ dependencies = [ "log", "multimap", "petgraph", - "prettyplease", - "prost", - "prost-types", + "prettyplease 0.1.25", + "prost 0.11.9", + "prost-types 0.11.9", "regex", "syn 1.0.109", "tempfile", "which", ] +[[package]] +name = "prost-build" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c55e02e35260070b6f716a2423c2ff1c3bb1642ddca6f99e1f26d06268a0e2d2" +dependencies = [ + "bytes", + "heck", + "itertools 0.11.0", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease 0.2.15", + "prost 0.12.3", + "prost-types 0.12.3", + "regex", + "syn 2.0.39", + "tempfile", + "which", +] + [[package]] name = "prost-derive" version = "0.11.9" @@ -2331,13 +2395,35 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "prost-derive" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efb6c9a1dd1def8e2124d17e83a20af56f1570d6c2d2bd9e266ccb768df3840e" +dependencies = [ + "anyhow", + "itertools 0.11.0", + "proc-macro2", + "quote", + "syn 2.0.39", +] + [[package]] name = "prost-types" version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13" dependencies = [ - "prost", + "prost 0.11.9", +] + +[[package]] +name = "prost-types" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "193898f59edcf43c26227dcd4c8427f00d99d61e95dcde58dabd49fa291d470e" +dependencies = [ + "prost 0.12.3", ] [[package]] @@ -3115,6 +3201,8 @@ dependencies = [ "num_cpus", "opentelemetry 0.20.0", "opentelemetry-otlp", + "prost 0.12.3", + "prost-types 0.12.3", "reqwest", "serde", "serde_json", @@ -3123,6 +3211,10 @@ dependencies = [ "thiserror", "tokenizers", "tokio", + "tonic 0.10.2", + "tonic-build 0.10.2", + "tonic-health", + "tonic-reflection", "tower-http", "tracing", "tracing-opentelemetry 0.21.0", @@ -3343,7 +3435,34 @@ dependencies = [ "hyper-timeout", "percent-encoding", "pin-project", - "prost", + "prost 0.11.9", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d560933a0de61cf715926b9cac824d4c883c2c43142f787595e48280c40a1d0e" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.21.5", + "bytes", + "h2", + "http", + "http-body", + "hyper", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost 0.12.3", "tokio", "tokio-stream", "tower", @@ -3358,13 +3477,52 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6fdaae4c2c638bb70fe42803a26fbd6fc6ac8c72f5c59f67ecc2a2dcabf4b07" dependencies = [ - "prettyplease", + "prettyplease 0.1.25", "proc-macro2", - "prost-build", + "prost-build 0.11.9", "quote", "syn 1.0.109", ] +[[package]] +name = "tonic-build" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d021fc044c18582b9a2408cd0dd05b1596e3ecdb5c4df822bb0183545683889" +dependencies = [ + "prettyplease 0.2.15", + "proc-macro2", + "prost-build 0.12.3", + "quote", + "syn 2.0.39", +] + +[[package]] +name = "tonic-health" +version = "0.10.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "f80db390246dfb46553481f6024f0082ba00178ea495dbb99e70ba9a4fafb5e1" +dependencies = [ + "async-stream", + "prost 0.12.3", + "tokio", + "tokio-stream", + "tonic 0.10.2", +] + +[[package]] +name = "tonic-reflection" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fa37c513df1339d197f4ba21d28c918b9ef1ac1768265f11ecb6b7f1cba1b76" +dependencies = [ + "prost 0.12.3", + "prost-types 0.12.3", + "tokio", + "tokio-stream", + "tonic 0.10.2", +] + [[package]] name = "tower" version = "0.4.13" diff --git a/backends/src/lib.rs b/backends/src/lib.rs index a0321020..c3620938 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -60,7 +60,7 @@ impl Backend { #[instrument(skip(self))] pub async fn health(&self) -> Result<(), BackendError> { - let result = if *self.health_receiver.borrow() { + if *self.health_receiver.borrow() { // The backend is healthy. Only do a basic health check by calling the // the underlying health method. @@ -86,9 +86,7 @@ impl Backend { ModelType::Classifier => self.predict(batch).await.map(|_| ()), ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()), } - }; - - result + } } #[instrument(skip(self))] @@ -103,11 +101,9 @@ impl Backend { self.backend_sender .send(BackendCommand::Embed(batch, Span::current(), sender)) .expect("No backend receiver. This is a bug."); - let result = receiver.await.expect( + receiver.await.expect( "Backend blocking task dropped the sender without send a response. This is a bug.", - ); - - result + ) } #[instrument(skip_all)] @@ -117,11 +113,9 @@ impl Backend { self.backend_sender .send(BackendCommand::Predict(batch, Span::current(), sender)) .expect("No backend receiver. This is a bug."); - let result = receiver.await.expect( + receiver.await.expect( "Backend blocking task dropped the sender without send a response. 
This is a bug.", - ); - - result + ) } } diff --git a/core/src/infer.rs b/core/src/infer.rs index d58c03d1..4232ea96 100644 --- a/core/src/infer.rs +++ b/core/src/infer.rs @@ -8,7 +8,7 @@ use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore} use tracing::{instrument, Span}; /// Inference struct -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Infer { tokenization: Tokenization, queue: Queue, diff --git a/load_tests/load.js b/load_tests/load.js index 7a945cce..a537c389 100644 --- a/load_tests/load.js +++ b/load_tests/load.js @@ -26,8 +26,8 @@ export const options = { load_test: { executor: 'constant-arrival-rate', duration: '30s', - preAllocatedVUs: 10000, - rate: 9000, + preAllocatedVUs: 5000, + rate: 1000, timeUnit: '1s', gracefulStop: '1s', }, diff --git a/load_tests/load_grpc.js b/load_tests/load_grpc.js new file mode 100644 index 00000000..80abbd96 --- /dev/null +++ b/load_tests/load_grpc.js @@ -0,0 +1,66 @@ +import {check} from 'k6'; +import grpc from 'k6/net/grpc'; +import {Trend} from 'k6/metrics'; + +const host = __ENV.HOST || '127.0.0.1:3000'; + +const totalTime = new Trend('total_time', true); +const tokenizationTIme = new Trend('tokenization_time', true); +const queueTime = new Trend('queue_time', true); +const inferenceTime = new Trend('inference_time', true); + +export const inputs = 'A path from a point approximately 330 metres east of the most south westerleasterly corner of Unit 4 Foundry Industrial Estate, then proceeding in a generally east-north-east direction for approximately 64 metres to a point approximately 282 metres east-south-east of the most easterly corner of Unit 2 Foundry Industrial Estate, Victoria Street, Widnes and approximately 259 metres east of the most southerly corner of Unit 4 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-north-east direction for approximately 350 metres to a point approximately 3 metres west-north-west of the most north westerly corner of the boundary fence of the scrap metal yard on the south side of Cornubia Road, Widnes, and approximately 47 metres west-south-west of the stub end of Cornubia Road be diverted to a 3 metre wide path from a point approximately 183 metres east-south-east of the most easterly corner of Unit 5 Foundry Industrial Estate, Victoria Street and approximately 272 metres east of the most north-easterly corner of 26 Ann Street West, Widnes, then proceeding in a generally north easterly direction for approximately 58 metres to a point approximately 216 metres east-south-east of the most easterly corner of Unit 4 Foundry Industrial Estate, Victoria Street and approximately 221 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally easterly direction for approximately 45 metres to a point approximately 265 metres east-south-east of the most north-easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 265 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-south-east direction for approximately 102 metres to a point approximately 366 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 463 metres east of the most north easterly corner of 22 Ann Street West, Widnes, then proceeding in a generally north-north-easterly direction for approximately 19 metres to a point approximately 368 metres 
east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 512 metres east of the most south easterly corner of 17 Batherton Close, Widnes then proceeding in a generally east-south, easterly direction for approximately 16 metres to a point approximately 420 metres east-south-east of the most southerly corner of Unit 2 Foundry'; + +export const options = { + thresholds: { + http_req_failed: ['rate==0'], + }, + scenarios: { + // throughput: { + // executor: 'shared-iterations', + // vus: 1000, + // iterations: 1000, + // maxDuration: '2m', + // gracefulStop: '1s', + // }, + load_test: { + executor: 'constant-arrival-rate', + duration: '30s', + preAllocatedVUs: 5000, + rate: 1000, + timeUnit: '1s', + gracefulStop: '1s', + }, + }, +}; + + +const client = new grpc.Client(); + +client.load(['definitions'], '../../proto/text_embeddings.proto'); + +export default function () { + client.connect(host, { + plaintext: true + }); + + const payload = { + inputs: inputs, + truncate: true, + }; + + const res = client.invoke('text_embeddings.v1.TextEmbeddings/Embed', payload); + + check(res, { + 'status is OK': (r) => r && r.status === grpc.StatusOK, + }); + + if (res.status === grpc.StatusOK) { + totalTime.add(res.headers["x-total-time"]); + tokenizationTIme.add(res.headers["x-tokenization-time"]); + queueTime.add(res.headers["x-queue-time"]); + inferenceTime.add(res.headers["x-inference-time"]); + } else { + console.log(res.error); + } +} diff --git a/proto/text_embeddings.proto b/proto/text_embeddings.proto new file mode 100644 index 00000000..b9557bfd --- /dev/null +++ b/proto/text_embeddings.proto @@ -0,0 +1,80 @@ +syntax = "proto3"; + +package text_embeddings.v1; + +service TextEmbeddings { + rpc Info (InfoRequest) returns (InfoResponse) { + option idempotency_level = IDEMPOTENT; + }; + + rpc Embed(EmbedRequest) returns (EmbedResponse); + rpc Predict(PredictRequest) returns (PredictResponse); + rpc Rerank(RerankRequest) returns (RerankResponse); +} + +message InfoRequest {} + +enum ModelType { + MODEL_TYPE_EMBEDDING = 0; + MODEL_TYPE_CLASSIFIER = 1; + MODEL_TYPE_RERANKER = 2; +} + +message InfoResponse { + string version = 1; + optional string sha = 2; + optional string docker_label = 3; + string model_id = 4; + optional string model_sha = 5; + string model_dtype = 6; + ModelType model_type = 7; + uint32 max_concurrent_requests = 8; + uint32 max_input_length = 9; + uint32 max_batch_tokens = 10; + optional uint32 max_batch_requests = 11; + uint32 max_client_batch_size = 12; + uint32 tokenization_workers = 13; +} + +message EmbedRequest { + string inputs = 1; + bool truncate = 2; + bool normalize = 3; +} + +message EmbedResponse { + repeated float embeddings = 1; +} + +message PredictRequest { + string inputs = 1; + bool truncate = 2; + bool raw_scores = 3; +} + +message Prediction { + float score = 1; + string label = 2; +} + +message PredictResponse { + repeated Prediction predictions = 1; +} + +message RerankRequest { + string query = 1; + repeated string texts = 2; + bool truncate = 3; + bool raw_scores = 4; + bool return_text = 5; +} + +message Rank { + uint32 index = 1; + optional string text = 2; + float score = 3; +} + +message RerankResponse { + repeated Rank ranks = 1; +} \ No newline at end of file diff --git a/router/Cargo.toml b/router/Cargo.toml index 596b038d..1716e248 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -16,12 +16,10 @@ path = "src/main.rs" [dependencies] anyhow = "1.0.71" -axum = { version = "0.6.4", features 
= ["json"], optional = true } -axum-tracing-opentelemetry = { version = "0.14.1", optional = true } text-embeddings-backend = { path = "../backends", features = ["clap"] } text-embeddings-core = { path = "../core" } clap = { version = "4.1.4", features = ["derive", "env"] } -futures = { version = "^0.3", optional = true } +futures = "^0.3" init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } hf-hub = { version = "0.3.0", features = ["tokio"] } num_cpus = "1.16.0" @@ -35,20 +33,33 @@ serde_json = "1.0.93" thiserror = "1.0.38" tokenizers = { version = "0.15.0", default-features=false, features=["onig", "esaxx_fast"] } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } -tower-http = { version = "0.4.0", features = ["cors"], optional = true } tracing = "0.1.37" tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } +veil = "0.1.6" + +# HTTP dependencies +axum = { version = "0.6.4", features = ["json"], optional = true } +axum-tracing-opentelemetry = { version = "0.14.1", optional = true } +tower-http = { version = "0.4.0", features = ["cors"], optional = true } utoipa = { version = "4.0.0", features = ["axum_extras"], optional = true } utoipa-swagger-ui = { version = "4.0.0", features = ["axum"], optional = true } -veil = "0.1.6" + +# gRPC dependencies +prost = { version = "0.12.1", optional = true } +prost-types = { version = "0.12.1", optional = true } +tonic = { version = "0.10.2", optional = true } +tonic-health = { version = "0.10.2", optional = true } +tonic-reflection = { version = "0.10.2", optional = true } [build-dependencies] vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } +tonic-build = { version = "0.10.2", optional = true } [features] default = ["candle", "http"] -http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:futures", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"] +http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"] +grpc = ["metrics-exporter-prometheus/http-listener", "dep:prost", "dep:prost-types", "dep:tonic", "dep:tonic-health", "dep:tonic-reflection", "dep:tonic-build"] mkl = ["text-embeddings-backend/mkl"] mkl-dynamic = ["text-embeddings-backend/mkl-dynamic"] accelerate = ["text-embeddings-backend/accelerate"] diff --git a/router/build.rs b/router/build.rs index f5eb8a26..cda55c66 100644 --- a/router/build.rs +++ b/router/build.rs @@ -22,5 +22,24 @@ fn main() -> Result<(), Box> { println!("cargo:rustc-env=DOCKER_LABEL={label}"); } + #[cfg(feature = "grpc")] + { + use std::env; + use std::fs; + use std::path::PathBuf; + + fs::create_dir("src/grpc/pb").unwrap_or(()); + + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + tonic_build::configure() + .build_client(false) + .build_server(true) + .file_descriptor_set_path(out_dir.join("descriptor.bin")) + .out_dir("src/grpc/pb") + .include_file("mod.rs") + .compile(&["../proto/text_embeddings.proto"], &["../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e)); + } + Ok(()) } diff --git a/router/src/grpc.rs b/router/src/grpc.rs deleted file mode 100644 index 8b137891..00000000 --- a/router/src/grpc.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/router/src/grpc/mod.rs b/router/src/grpc/mod.rs new file mode 100644 index 00000000..31a2cacf --- /dev/null +++ b/router/src/grpc/mod.rs @@ -0,0 +1,4 @@ +mod pb; +pub(crate) mod server; + +use 
pb::text_embeddings::v1::{text_embeddings_server::TextEmbeddingsServer, *}; diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs new file mode 100644 index 00000000..e319d49f --- /dev/null +++ b/router/src/grpc/server.rs @@ -0,0 +1,575 @@ +use crate::grpc::{ + EmbedRequest, EmbedResponse, InfoRequest, InfoResponse, PredictRequest, PredictResponse, + Prediction, Rank, RerankRequest, RerankResponse, +}; +use crate::{grpc, shutdown, ErrorResponse, ErrorType, Info, ModelType}; +use futures::future::join_all; +use metrics_exporter_prometheus::PrometheusBuilder; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; +use text_embeddings_core::infer::Infer; +use tonic::codegen::http::HeaderMap; +use tonic::metadata::MetadataMap; +use tonic::transport::Server; +use tonic::{Code, Extensions, Request, Response, Status}; +use tracing::instrument; + +#[derive(Debug)] +struct TextEmbeddingsService { + infer: Infer, + info: Info, +} + +#[tonic::async_trait] +impl grpc::text_embeddings_server::TextEmbeddings for TextEmbeddingsService { + async fn info(&self, _request: Request) -> Result, Status> { + let model_type = match self.info.model_type { + ModelType::Classifier(_) => grpc::ModelType::Classifier, + ModelType::Embedding(_) => grpc::ModelType::Embedding, + ModelType::Reranker(_) => grpc::ModelType::Reranker, + }; + + Ok(Response::new(InfoResponse { + version: self.info.version.to_string(), + sha: self.info.sha.map(|s| s.to_string()), + docker_label: self.info.docker_label.map(|s| s.to_string()), + model_id: self.info.model_id.clone(), + model_sha: self.info.model_sha.clone(), + model_dtype: self.info.model_dtype.clone(), + model_type: model_type.into(), + max_concurrent_requests: self.info.max_concurrent_requests as u32, + max_input_length: self.info.max_input_length as u32, + max_batch_tokens: self.info.max_batch_tokens as u32, + max_batch_requests: self.info.max_batch_requests.map(|v| v as u32), + max_client_batch_size: self.info.max_client_batch_size as u32, + tokenization_workers: self.info.tokenization_workers as u32, + })) + } + #[instrument( + skip_all, + fields(total_time, tokenization_time, queue_time, inference_time,) + )] + async fn embed( + &self, + request: Request, + ) -> Result, Status> { + let span = tracing::Span::current(); + let start_time = Instant::now(); + + let request = request.into_inner(); + + let ( + compute_chars, + compute_tokens, + tokenization_time, + queue_time, + inference_time, + response, + ) = { + metrics::increment_counter!("te_request_count", "method" => "single"); + + let compute_chars = request.inputs.chars().count(); + + let permit = self + .infer + .try_acquire_permit() + .map_err(ErrorResponse::from)?; + let response = self + .infer + .embed(request.inputs, request.truncate, request.normalize, permit) + .await + .map_err(ErrorResponse::from)?; + + metrics::increment_counter!("te_request_success", "method" => "single"); + + ( + compute_chars, + response.prompt_tokens, + response.tokenization, + response.queue, + response.inference, + EmbedResponse { + embeddings: response.results, + }, + ) + }; + + let total_time = start_time.elapsed(); + + // Tracing metadata + span.record("total_time", format!("{total_time:?}")); + span.record("tokenization_time", format!("{tokenization_time:?}")); + span.record("queue_time", format!("{queue_time:?}")); + span.record("inference_time", format!("{inference_time:?}")); + + // Headers + let mut headers = HeaderMap::new(); + headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); + 
headers.insert( + "x-compute-time", + total_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-characters", + compute_chars.to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-tokens", + compute_tokens.to_string().parse().unwrap(), + ); + headers.insert( + "x-total-time", + total_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-tokenization-time", + tokenization_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-queue-time", + queue_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-inference-time", + inference_time.as_millis().to_string().parse().unwrap(), + ); + + // Metrics + metrics::histogram!("te_request_duration", total_time.as_secs_f64()); + metrics::histogram!( + "te_request_tokenization_duration", + tokenization_time.as_secs_f64() + ); + metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); + metrics::histogram!( + "te_request_inference_duration", + inference_time.as_secs_f64() + ); + + tracing::info!("Success"); + + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + response, + Extensions::default(), + )) + } + #[instrument( + skip_all, + fields(total_time, tokenization_time, queue_time, inference_time,) + )] + async fn predict( + &self, + request: Request, + ) -> Result, Status> { + let span = tracing::Span::current(); + let start_time = Instant::now(); + + let request = request.into_inner(); + + // Closure for predict + let predict_inner = move |inputs: String, + truncate: bool, + raw_scores: bool, + infer: Infer, + info: Info| async move { + let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + let response = infer + .predict(inputs, truncate, raw_scores, permit) + .await + .map_err(ErrorResponse::from)?; + + let id2label = match &info.model_type { + ModelType::Classifier(classifier) => &classifier.id2label, + ModelType::Reranker(classifier) => &classifier.id2label, + _ => panic!(), + }; + + let mut predictions: Vec = { + // Map score to label + response + .results + .into_iter() + .enumerate() + .map(|(i, s)| Prediction { + score: s, + label: id2label.get(&i.to_string()).unwrap().clone(), + }) + .collect() + }; + // Reverse sort + predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + predictions.reverse(); + + Ok::<(usize, Duration, Duration, Duration, Vec), ErrorResponse>(( + response.prompt_tokens, + response.tokenization, + response.queue, + response.inference, + predictions, + )) + }; + + let ( + compute_chars, + compute_tokens, + tokenization_time, + queue_time, + inference_time, + predictions, + ) = { + metrics::increment_counter!("te_request_count", "method" => "single"); + + let compute_chars = request.inputs.chars().count(); + let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner( + request.inputs, + request.truncate, + request.raw_scores, + self.infer.clone(), + self.info.clone(), + ) + .await?; + + metrics::increment_counter!("te_request_success", "method" => "single"); + + ( + compute_chars, + prompt_tokens, + tokenization, + queue, + inference, + predictions, + ) + }; + + let total_time = start_time.elapsed(); + + // Tracing metadata + span.record("total_time", format!("{total_time:?}")); + span.record("tokenization_time", format!("{tokenization_time:?}")); + span.record("queue_time", format!("{queue_time:?}")); + span.record("inference_time", format!("{inference_time:?}")); + + // Headers + let mut headers = HeaderMap::new(); + 
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); + headers.insert( + "x-compute-time", + total_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-characters", + compute_chars.to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-tokens", + compute_tokens.to_string().parse().unwrap(), + ); + headers.insert( + "x-total-time", + total_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-tokenization-time", + tokenization_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-queue-time", + queue_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-inference-time", + inference_time.as_millis().to_string().parse().unwrap(), + ); + + // Metrics + metrics::histogram!("te_request_duration", total_time.as_secs_f64()); + metrics::histogram!( + "te_request_tokenization_duration", + tokenization_time.as_secs_f64() + ); + metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); + metrics::histogram!( + "te_request_inference_duration", + inference_time.as_secs_f64() + ); + + tracing::info!("Success"); + + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + PredictResponse { predictions }, + Extensions::default(), + )) + } + #[instrument( + skip_all, + fields(total_time, tokenization_time, queue_time, inference_time,) + )] + async fn rerank( + &self, + request: Request, + ) -> Result, Status> { + let span = tracing::Span::current(); + let start_time = Instant::now(); + + let request = request.into_inner(); + + match &self.info.model_type { + ModelType::Classifier(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a re-ranker model".to_string(); + tracing::error!("{message}"); + Err(Status::new(Code::FailedPrecondition, message)) + } + ModelType::Reranker(_) => Ok(()), + ModelType::Embedding(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a classifier model".to_string(); + tracing::error!("{message}"); + Err(Status::new(Code::FailedPrecondition, message)) + } + }?; + + // Closure for rerank + let rerank_inner = move |query: String, + text: String, + truncate: bool, + raw_scores: bool, + infer: Infer| async move { + let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + + let response = infer + .predict((query, text), truncate, raw_scores, permit) + .await + .map_err(ErrorResponse::from)?; + + let score = response.results[0]; + + Ok::<(usize, Duration, Duration, Duration, f32), ErrorResponse>(( + response.prompt_tokens, + response.tokenization, + response.queue, + response.inference, + score, + )) + }; + + let ( + compute_chars, + compute_tokens, + tokenization_time, + queue_time, + inference_time, + response, + ) = { + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let batch_size = request.texts.len(); + if batch_size > self.info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + self.info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; + } + + let mut futures = Vec::with_capacity(batch_size); + let query_chars = request.query.chars().count(); + let mut compute_chars = query_chars * batch_size; + + for text in &request.texts { + 
compute_chars += text.chars().count(); + let local_infer = self.infer.clone(); + futures.push(rerank_inner( + request.query.clone(), + text.clone(), + request.truncate, + request.raw_scores, + local_infer, + )) + } + let results = join_all(futures) + .await + .into_iter() + .collect::, ErrorResponse>>( + )?; + + let mut ranks = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + for (index, r) in results.into_iter().enumerate() { + total_compute_tokens += r.0; + total_tokenization_time += r.1.as_nanos() as u64; + total_queue_time += r.2.as_nanos() as u64; + total_inference_time += r.3.as_nanos() as u64; + let text = if request.return_text { + Some(request.texts[index].clone()) + } else { + None + }; + + ranks.push(Rank { + index: index as u32, + text, + score: r.4, + }) + } + + // Reverse sort + ranks.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + ranks.reverse(); + + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + ( + compute_chars, + total_compute_tokens, + Duration::from_nanos(total_tokenization_time / batch_size), + Duration::from_nanos(total_queue_time / batch_size), + Duration::from_nanos(total_inference_time / batch_size), + RerankResponse { ranks }, + ) + }; + + let total_time = start_time.elapsed(); + + // Tracing metadata + span.record("total_time", format!("{total_time:?}")); + span.record("tokenization_time", format!("{tokenization_time:?}")); + span.record("queue_time", format!("{queue_time:?}")); + span.record("inference_time", format!("{inference_time:?}")); + + // Headers + let mut headers = HeaderMap::new(); + headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); + headers.insert( + "x-compute-time", + total_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-characters", + compute_chars.to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-tokens", + compute_tokens.to_string().parse().unwrap(), + ); + headers.insert( + "x-total-time", + total_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-tokenization-time", + tokenization_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-queue-time", + queue_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-inference-time", + inference_time.as_millis().to_string().parse().unwrap(), + ); + + // Metrics + metrics::histogram!("te_request_duration", total_time.as_secs_f64()); + metrics::histogram!( + "te_request_tokenization_duration", + tokenization_time.as_secs_f64() + ); + metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); + metrics::histogram!( + "te_request_inference_duration", + inference_time.as_secs_f64() + ); + + tracing::info!("Success"); + + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + response, + Extensions::default(), + )) + } +} + +pub async fn run( + infer: Infer, + info: Info, + addr: SocketAddr, + prom_builder: PrometheusBuilder, +) -> Result<(), anyhow::Error> { + prom_builder.install()?; + tracing::info!("Serving Prometheus metrics: 0.0.0.0:9000"); + + // Liveness service + let (mut health_reporter, health_service) = tonic_health::server::health_reporter(); + health_reporter + .set_serving::>() + .await; + + let mut health_watcher = infer.health_watcher(); + + tokio::spawn(async move { + while health_watcher.changed().await.is_ok() { + let 
health = *health_watcher.borrow_and_update(); + match health { + true => { + health_reporter + .set_serving::>() + .await + } + false => { + health_reporter + .set_not_serving::>() + .await + } + } + } + }); + + // gRPC reflection + let file_descriptor_set: &[u8] = tonic::include_file_descriptor_set!("descriptor"); + let reflection_service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(file_descriptor_set) + .build()?; + + // Main service + let service = TextEmbeddingsService { infer, info }; + + // Create gRPC server + tracing::info!("Starting gRPC server: {}", &addr); + Server::builder() + .add_service(health_service) + .add_service(reflection_service) + .add_service(grpc::TextEmbeddingsServer::new(service)) + .serve_with_shutdown(addr, shutdown::shutdown_signal()) + .await?; + Ok(()) +} + +impl From for Status { + fn from(value: ErrorResponse) -> Self { + let code = match value.error_type { + ErrorType::Unhealthy => Code::Unavailable, + ErrorType::Backend => Code::FailedPrecondition, + ErrorType::Overloaded => Code::ResourceExhausted, + ErrorType::Validation => Code::InvalidArgument, + ErrorType::Tokenizer => Code::FailedPrecondition, + }; + + Status::new(code, value.error) + } +} diff --git a/router/src/http.rs b/router/src/http/mod.rs similarity index 100% rename from router/src/http.rs rename to router/src/http/mod.rs diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 60e28552..0d947b8a 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -1,12 +1,10 @@ /// HTTP Server logic use crate::http::types::{ - EmbedRequest, EmbedResponse, ErrorResponse, ErrorType, Input, OpenAICompatEmbedding, - OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, - PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, - Sequence, + EmbedRequest, EmbedResponse, Input, OpenAICompatEmbedding, OpenAICompatErrorResponse, + OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest, + PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, }; -use crate::prometheus::prometheus_builer; -use crate::{ClassifierModel, EmbeddingModel, Info, ModelType}; +use crate::{shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType}; use anyhow::Context; use axum::extract::Extension; use axum::http::HeaderValue; @@ -15,14 +13,13 @@ use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::middleware::OtelAxumLayer; use futures::future::join_all; -use metrics_exporter_prometheus::PrometheusHandle; +use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; use std::env; use std::net::SocketAddr; use std::time::{Duration, Instant}; use text_embeddings_backend::BackendError; use text_embeddings_core::infer::{Infer, InferResponse}; use text_embeddings_core::TextEmbeddingsError; -use tokio::signal; use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::instrument; use utoipa::OpenApi; @@ -106,6 +103,7 @@ async fn predict( let id2label = match &info.model_type { ModelType::Classifier(classifier) => &classifier.id2label, + ModelType::Reranker(classifier) => &classifier.id2label, _ => panic!(), }; @@ -309,17 +307,14 @@ async fn rerank( let start_time = Instant::now(); match &info.model_type { - ModelType::Classifier(classifier) => { - if classifier.id2label.len() > 1 { - metrics::increment_counter!("te_request_failure", "err" => 
"model_type"); - let message = "model is not a re-ranker model".to_string(); - Err(TextEmbeddingsError::Backend(BackendError::Inference( - message, - ))) - } else { - Ok(()) - } + ModelType::Classifier(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a re-ranker model".to_string(); + Err(TextEmbeddingsError::Backend(BackendError::Inference( + message, + ))) } + ModelType::Reranker(_) => Ok(()), ModelType::Embedding(_) => { metrics::increment_counter!("te_request_failure", "err" => "model_type"); let message = "model is not a classifier model".to_string(); @@ -880,6 +875,7 @@ pub async fn run( infer: Infer, info: Info, addr: SocketAddr, + prom_builder: PrometheusBuilder, ) -> Result<(), anyhow::Error> { // OpenAPI documentation #[derive(OpenApi)] @@ -935,16 +931,13 @@ pub async fn run( // Finally, convert to AllowOrigin let allow_origin: Option = env::var("CORS_ALLOW_ORIGIN").ok().map(|cors_allow_origin| { - let cors_allow_origin = cors_allow_origin.split(","); + let cors_allow_origin = cors_allow_origin.split(','); AllowOrigin::list( cors_allow_origin.map(|origin| origin.parse::().unwrap()), ) }); - let prometheus_builder = - prometheus_builer(info.max_input_length).context("failed to build prometheus exporter")?; - - let prom_handle = prometheus_builder + let prom_handle = prom_builder .install_recorder() .context("failed to install metrics recorder")?; @@ -976,16 +969,15 @@ pub async fn run( // Set default routes let app = match &info.model_type { - ModelType::Classifier(classifier) => { - if classifier.id2label.len() > 1 { - app.route("/", post(predict)) - // AWS Sagemaker route - .route("/invocations", post(predict)) - } else { - app.route("/", post(rerank)) - // AWS Sagemaker route - .route("/invocations", post(rerank)) - } + ModelType::Classifier(_) => { + app.route("/", post(predict)) + // AWS Sagemaker route + .route("/invocations", post(predict)) + } + ModelType::Reranker(_) => { + app.route("/", post(rerank)) + // AWS Sagemaker route + .route("/invocations", post(rerank)) } ModelType::Embedding(_) => { app.route("/", post(embed)) @@ -1005,55 +997,12 @@ pub async fn run( axum::Server::bind(&addr) .serve(app.into_make_service()) // Wait until all requests are finished to shut down - .with_graceful_shutdown(shutdown_signal()) + .with_graceful_shutdown(shutdown::shutdown_signal()) .await?; Ok(()) } -/// Shutdown signal handler -async fn shutdown_signal() { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - - tokio::select! 
{ - _ = ctrl_c => {}, - _ = terminate => {}, - } - - tracing::info!("signal received, starting graceful shutdown"); - opentelemetry::global::shutdown_tracer_provider(); -} - -impl From for ErrorResponse { - fn from(err: TextEmbeddingsError) -> Self { - let error_type = match err { - TextEmbeddingsError::Tokenizer(_) => ErrorType::Tokenizer, - TextEmbeddingsError::Validation(_) => ErrorType::Validation, - TextEmbeddingsError::Overloaded(_) => ErrorType::Overloaded, - TextEmbeddingsError::Backend(_) => ErrorType::Backend, - }; - Self { - error: err.to_string(), - error_type, - } - } -} - impl From<&ErrorType> for StatusCode { fn from(value: &ErrorType) -> Self { match value { diff --git a/router/src/http/types.rs b/router/src/http/types.rs index 9d004907..336ae811 100644 --- a/router/src/http/types.rs +++ b/router/src/http/types.rs @@ -1,3 +1,4 @@ +use crate::ErrorType; use serde::de::{SeqAccess, Visitor}; use serde::{de, Deserialize, Deserializer, Serialize}; use serde_json::json; @@ -314,21 +315,6 @@ fn default_normalize() -> bool { #[schema(example = json!([["0.0", "1.0", "2.0"]]))] pub(crate) struct EmbedResponse(pub Vec>); -#[derive(Serialize, ToSchema)] -pub(crate) enum ErrorType { - Unhealthy, - Backend, - Overloaded, - Validation, - Tokenizer, -} - -#[derive(Serialize, ToSchema)] -pub(crate) struct ErrorResponse { - pub error: String, - pub error_type: ErrorType, -} - #[derive(Serialize, ToSchema)] pub(crate) struct OpenAICompatErrorResponse { pub message: String, diff --git a/router/src/lib.rs b/router/src/lib.rs index c2db692b..06772883 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -4,6 +4,7 @@ use serde::Serialize; use std::collections::HashMap; use std::net::SocketAddr; use text_embeddings_core::infer::Infer; +use text_embeddings_core::TextEmbeddingsError; mod prometheus; @@ -12,12 +13,22 @@ mod http; #[cfg(feature = "grpc")] mod grpc; +mod shutdown; pub async fn run(infer: Infer, info: Info, addr: SocketAddr) -> Result<()> { + let prom_builder = prometheus::prometheus_builer(info.max_input_length)?; + if cfg!(feature = "http") { #[cfg(feature = "http")] { - return http::server::run(infer, info, addr).await; + return http::server::run(infer, info, addr, prom_builder).await; + } + } + + if cfg!(feature = "grpc") { + #[cfg(feature = "grpc")] + { + return grpc::server::run(infer, info, addr, prom_builder).await; } } @@ -27,16 +38,16 @@ pub async fn run(infer: Infer, info: Info, addr: SocketAddr) -> Result<()> { #[derive(Clone, Debug, Serialize)] #[cfg_attr(feature = "http", derive(utoipa::ToSchema))] pub struct EmbeddingModel { - #[schema(example = "cls")] + #[cfg_attr(feature = "http", schema(example = "cls"))] pub pooling: String, } #[derive(Clone, Debug, Serialize)] #[cfg_attr(feature = "http", derive(utoipa::ToSchema))] pub struct ClassifierModel { - #[schema(example = json!({"0": "LABEL"}))] + #[cfg_attr(feature = "http", schema(example = json!({"0": "LABEL"})))] pub id2label: HashMap, - #[schema(example = json!({"LABEL": 0}))] + #[cfg_attr(feature = "http", schema(example = json!({"LABEL": 0})))] pub label2id: HashMap, } @@ -46,37 +57,76 @@ pub struct ClassifierModel { pub enum ModelType { Classifier(ClassifierModel), Embedding(EmbeddingModel), + Reranker(ClassifierModel), } #[derive(Clone, Debug, Serialize)] #[cfg_attr(feature = "http", derive(utoipa::ToSchema))] pub struct Info { /// Model info - #[schema(example = "thenlper/gte-base")] + #[cfg_attr(feature = "http", schema(example = "thenlper/gte-base"))] pub model_id: String, - #[schema(nullable = true, example 
= "fca14538aa9956a46526bd1d0d11d69e19b5a101")] + #[cfg_attr( + feature = "http", + schema(nullable = true, example = "fca14538aa9956a46526bd1d0d11d69e19b5a101") + )] pub model_sha: Option, - #[schema(example = "float16")] + #[cfg_attr(feature = "http", schema(example = "float16"))] pub model_dtype: String, pub model_type: ModelType, /// Router Parameters - #[schema(example = "128")] + #[cfg_attr(feature = "http", schema(example = "128"))] pub max_concurrent_requests: usize, - #[schema(example = "512")] + #[cfg_attr(feature = "http", schema(example = "512"))] pub max_input_length: usize, - #[schema(example = "2048")] + #[cfg_attr(feature = "http", schema(example = "2048"))] pub max_batch_tokens: usize, - #[schema(nullable = true, example = "null", default = "null")] + #[cfg_attr( + feature = "http", + schema(nullable = true, example = "null", default = "null") + )] pub max_batch_requests: Option, - #[schema(example = "32")] + #[cfg_attr(feature = "http", schema(example = "32"))] pub max_client_batch_size: usize, - #[schema(example = "4")] + #[cfg_attr(feature = "http", schema(example = "4"))] pub tokenization_workers: usize, /// Router Info - #[schema(example = "0.5.0")] + #[cfg_attr(feature = "http", schema(example = "0.5.0"))] pub version: &'static str, - #[schema(nullable = true, example = "null")] + #[cfg_attr(feature = "http", schema(nullable = true, example = "null"))] pub sha: Option<&'static str>, - #[schema(nullable = true, example = "null")] + #[cfg_attr(feature = "http", schema(nullable = true, example = "null"))] pub docker_label: Option<&'static str>, } + +#[derive(Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] +pub enum ErrorType { + Unhealthy, + Backend, + Overloaded, + Validation, + Tokenizer, +} + +#[derive(Serialize)] +#[cfg_attr(feature = "http", derive(utoipa::ToSchema))] +pub struct ErrorResponse { + pub error: String, + pub error_type: ErrorType, +} + +impl From for ErrorResponse { + fn from(err: TextEmbeddingsError) -> Self { + let error_type = match err { + TextEmbeddingsError::Tokenizer(_) => ErrorType::Tokenizer, + TextEmbeddingsError::Validation(_) => ErrorType::Validation, + TextEmbeddingsError::Overloaded(_) => ErrorType::Overloaded, + TextEmbeddingsError::Backend(_) => ErrorType::Backend, + }; + Self { + error: err.to_string(), + error_type, + } + } +} diff --git a/router/src/main.rs b/router/src/main.rs index b9c162ba..92dcea61 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -231,14 +231,23 @@ async fn main() -> Result<()> { // Info model type let model_type = match &backend_model_type { - text_embeddings_backend::ModelType::Classifier => ModelType::Classifier(ClassifierModel { - id2label: config + text_embeddings_backend::ModelType::Classifier => { + let id2label = config .id2label - .context("`config.json` does not contain `id2label`")?, - label2id: config - .label2id - .context("`config.json` does not contain `label2id`")?, - }), + .context("`config.json` does not contain `id2label`")?; + let n_classes = id2label.len(); + let classifier_model = ClassifierModel { + id2label, + label2id: config + .label2id + .context("`config.json` does not contain `label2id`")?, + }; + if n_classes > 1 { + ModelType::Classifier(classifier_model) + } else { + ModelType::Reranker(classifier_model) + } + } text_embeddings_backend::ModelType::Embedding(pool) => { ModelType::Embedding(EmbeddingModel { pooling: pool.to_string(), @@ -314,7 +323,7 @@ async fn main() -> Result<()> { dtype.clone(), backend_model_type, args.uds_path, - 
args.otlp_endpoint, + args.otlp_endpoint.clone(), ) .context("Could not create backend")?; backend @@ -345,7 +354,7 @@ async fn main() -> Result<()> { model_dtype: dtype.to_string(), model_type, max_concurrent_requests: args.max_concurrent_requests, - max_input_length: config.max_position_embeddings, + max_input_length, max_batch_tokens: args.max_batch_tokens, tokenization_workers, max_batch_requests, @@ -369,6 +378,12 @@ async fn main() -> Result<()> { text_embeddings_router::run(infer, info, addr) .await .unwrap(); + + if args.otlp_endpoint.is_some() { + // Shutdown tracer + global::shutdown_tracer_provider(); + } + Ok(()) } diff --git a/router/src/shutdown.rs b/router/src/shutdown.rs new file mode 100644 index 00000000..471eaf14 --- /dev/null +++ b/router/src/shutdown.rs @@ -0,0 +1,29 @@ +use tokio::signal; + +/// Shutdown signal handler +pub(crate) async fn shutdown_signal() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + + tracing::info!("signal received, starting graceful shutdown"); + opentelemetry::global::shutdown_tracer_provider(); +} From 94919a2868b01d504f688561afcbf830dd95fdd7 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 24 Nov 2023 17:51:31 +0100 Subject: [PATCH 05/12] add images --- .github/workflows/build_75.yaml | 36 +++++++++++++++++++++++++++++++- .github/workflows/build_80.yaml | 33 +++++++++++++++++++++++++++++ .github/workflows/build_86.yaml | 34 ++++++++++++++++++++++++++++++ .github/workflows/build_89.yaml | 34 ++++++++++++++++++++++++++++++ .github/workflows/build_90.yaml | 35 +++++++++++++++++++++++++++++++ .github/workflows/build_cpu.yaml | 33 +++++++++++++++++++++++++++++ Dockerfile | 27 ++++++++++++++++++++++-- Dockerfile-cuda | 30 ++++++++++++++++++++++++-- 8 files changed, 257 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build_75.yaml b/.github/workflows/build_75.yaml index dc6ee6b7..d744f96c 100644 --- a/.github/workflows/build_75.yaml +++ b/.github/workflows/build_75.yaml @@ -77,7 +77,7 @@ tags: | type=semver,pattern=turing-{{version}} type=semver,pattern=turing-{{major}}.{{minor}} - type=raw,value=turing-latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} + type=raw,value=turing-latest type=raw,value=turing-sha-${{ env.GITHUB_SHA_SHORT }} - name: Build and push Docker image id: build-and-push-75 @@ -99,3 +99,37 @@ labels: ${{ steps.meta-75.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-75,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-75,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-75-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=turing-{{version}}+grpc + type=semver,pattern=turing-{{major}}.{{minor}}+grpc + type=raw,value=turing-latest+grpc + type=raw,value=turing-sha-${{ 
env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-75-grpc + uses: docker/build-push-action@v4 + with: + context: . + target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + CUDA_COMPUTE_CAP=75 + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + DEFAULT_USE_FLASH_ATTENTION=False + tags: ${{ steps.meta-75-grpc.outputs.tags }} + labels: ${{ steps.meta-75-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-75,mode=max diff --git a/.github/workflows/build_80.yaml b/.github/workflows/build_80.yaml index f5d6fe5a..589f1aef 100644 --- a/.github/workflows/build_80.yaml +++ b/.github/workflows/build_80.yaml @@ -98,3 +98,36 @@ labels: ${{ steps.meta-80.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-80,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-80,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-80-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern={{version}}+grpc + type=semver,pattern={{major}}.{{minor}}+grpc + type=raw,value=latest+grpc + type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-80-grpc + uses: docker/build-push-action@v4 + with: + context: . + target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + CUDA_COMPUTE_CAP=80 + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-80-grpc.outputs.tags }} + labels: ${{ steps.meta-80-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-80,mode=max diff --git a/.github/workflows/build_86.yaml b/.github/workflows/build_86.yaml index bd824414..d799abaf 100644 --- a/.github/workflows/build_86.yaml +++ b/.github/workflows/build_86.yaml @@ -98,3 +98,37 @@ labels: ${{ steps.meta-86.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-86,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-86,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-86-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=86-{{version}}+grpc + type=semver,pattern=86-{{major}}.{{minor}}+grpc + type=raw,value=86-latest+grpc + type=raw,value=86-sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-86-grpc + uses: docker/build-push-action@v4 + with: + context: . 
+ target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + CUDA_COMPUTE_CAP=86 + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-86-grpc.outputs.tags }} + labels: ${{ steps.meta-86-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-86,mode=max + diff --git a/.github/workflows/build_89.yaml b/.github/workflows/build_89.yaml index a5a5be7c..5126ab69 100644 --- a/.github/workflows/build_89.yaml +++ b/.github/workflows/build_89.yaml @@ -98,3 +98,37 @@ labels: ${{ steps.meta-89.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-89,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-89,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-89-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=89-{{version}}+grpc + type=semver,pattern=89-{{major}}.{{minor}}+grpc + type=raw,value=89-latest+grpc + type=raw,value=89-sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-89-grpc + uses: docker/build-push-action@v4 + with: + context: . + target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + CUDA_COMPUTE_CAP=89 + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-89-grpc.outputs.tags }} + labels: ${{ steps.meta-89-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-89,mode=max + diff --git a/.github/workflows/build_90.yaml b/.github/workflows/build_90.yaml index 9c6f2d6a..63fc3b6f 100644 --- a/.github/workflows/build_90.yaml +++ b/.github/workflows/build_90.yaml @@ -98,4 +98,39 @@ labels: ${{ steps.meta-90.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-90,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-90,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-90-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=hopper-{{version}}+grpc + type=semver,pattern=hopper-{{major}}.{{minor}}+grpc + type=raw,value=hopper-latest+grpc + type=raw,value=hopper-sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-90-grpc + uses: docker/build-push-action@v4 + with: + context: . 
+ target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + CUDA_COMPUTE_CAP=90 + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-90-grpc.outputs.tags }} + labels: ${{ steps.meta-90-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-90,mode=max + + diff --git a/.github/workflows/build_cpu.yaml b/.github/workflows/build_cpu.yaml index e0237c81..bc6623c7 100644 --- a/.github/workflows/build_cpu.yaml +++ b/.github/workflows/build_cpu.yaml @@ -97,3 +97,36 @@ labels: ${{ steps.meta-cpu.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-cpu,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-cpu,mode=max + - name: Extract metadata (tags, labels) for Docker + id: meta-cpu-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=cpu-{{version}}+grpc + type=semver,pattern=cpu-{{major}}.{{minor}}+grpc + type=raw,value=cpu-latest+grpc + type=raw,value=cpu-sha-${{ env.GITHUB_SHA_SHORT }}+grpc + - name: Build and push Docker image + id: build-and-push-cpu-grpc + uses: docker/build-push-action@v4 + with: + context: . + target: grpc + file: Dockerfile + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-cpu-grpc.outputs.tags }} + labels: ${{ steps.meta-cpu-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-cpu,mode=max + diff --git a/Dockerfile b/Dockerfile index f557c1d3..07baafb6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -51,9 +51,23 @@ COPY router router COPY Cargo.toml ./ COPY Cargo.lock ./ +FROM builder as http-builder + RUN cargo build --release --bin text-embeddings-router -F candle -F mkl-dynamic --no-default-features && sccache -s -FROM debian:bookworm-slim +FROM builder as grpc-builder + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY proto proto + +RUN cargo build --release --bin text-embeddings-router -F grpc -F candle -F mkl-dynamic --no-default-features && sccache -s + +FROM debian:bookworm-slim as base ENV HUGGINGFACE_HUB_CACHE=/data \ PORT=80 \ @@ -80,7 +94,16 @@ COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /u COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2 COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so -COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router +FROM base 
as grpc + +COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] + +FROM base + +COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router ENTRYPOINT ["text-embeddings-router"] CMD ["--json-output"] \ No newline at end of file diff --git a/Dockerfile-cuda b/Dockerfile-cuda index 98e7f93c..77d43cd7 100644 --- a/Dockerfile-cuda +++ b/Dockerfile-cuda @@ -70,6 +70,8 @@ COPY router router COPY Cargo.toml ./ COPY Cargo.lock ./ +FROM builder as http-builder + RUN if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \ then \ cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F static-linking --no-default-features && sccache -s; \ @@ -77,7 +79,22 @@ RUN if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \ cargo build --release --bin text-embeddings-router -F candle-cuda -F static-linking --no-default-features && sccache -s; \ fi; -FROM nvidia/cuda:12.0.0-base-ubuntu22.04 +FROM builder as grpc-builder + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +RUN if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \ + then \ + cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F static-linking -F grpc --no-default-features && sccache -s; \ + else \ + cargo build --release --bin text-embeddings-router -F candle-cuda -F static-linking -F grpc --no-default-features && sccache -s; \ + fi; + +FROM nvidia/cuda:12.0.0-base-ubuntu22.04 as base ARG DEFAULT_USE_FLASH_ATTENTION=True @@ -85,7 +102,16 @@ ENV HUGGINGFACE_HUB_CACHE=/data \ PORT=80 \ USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION -COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router +FROM base as grpc + +COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] + +FROM base + +COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router ENTRYPOINT ["text-embeddings-router"] CMD ["--json-output"] \ No newline at end of file From cb2eeabbdc6d14d0b572a81caec8d20aa8f15f8d Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 24 Nov 2023 17:54:41 +0100 Subject: [PATCH 06/12] fix openapi --- Cargo.lock | 2 +- router/src/http/types.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 225b8faa..108e1f57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3456,7 +3456,7 @@ dependencies = [ "base64 0.21.5", "bytes", "h2", - "http", + "http 0.2.11", "http-body", "hyper", "hyper-timeout", diff --git a/router/src/http/types.rs b/router/src/http/types.rs index 336ae811..dbdba0f1 100644 --- a/router/src/http/types.rs +++ b/router/src/http/types.rs @@ -272,7 +272,7 @@ pub(crate) struct OpenAICompatRequest { pub(crate) struct OpenAICompatEmbedding { #[schema(example = "embedding")] pub object: &'static str, - #[schema(example = json!(["0.0", "1.0", "2.0"]))] + #[schema(example = json!([0.0, 1.0, 2.0]))] pub embedding: Vec, #[schema(example = "0")] pub index: usize, @@ -312,7 +312,7 @@ 
fn default_normalize() -> bool { } #[derive(Serialize, ToSchema)] -#[schema(example = json!([["0.0", "1.0", "2.0"]]))] +#[schema(example = json!([[0.0, 1.0, 2.0]]))] pub(crate) struct EmbedResponse(pub Vec>); #[derive(Serialize, ToSchema)] From dbb232bb039e8037f5cae60c5b4ae4cf8e3afe4c Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 24 Nov 2023 18:38:28 +0100 Subject: [PATCH 07/12] add unzip --- Dockerfile-cuda | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Dockerfile-cuda b/Dockerfile-cuda index 77d43cd7..17cfb615 100644 --- a/Dockerfile-cuda +++ b/Dockerfile-cuda @@ -81,6 +81,10 @@ RUN if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \ FROM builder as grpc-builder +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + unzip \ + && rm -rf /var/lib/apt/lists/* + RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ From 1d61277ff07bad4018517ed457fb8b403f3d5d40 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 24 Nov 2023 19:14:31 +0100 Subject: [PATCH 08/12] add proto --- Dockerfile-cuda | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Dockerfile-cuda b/Dockerfile-cuda index 17cfb615..bad00a70 100644 --- a/Dockerfile-cuda +++ b/Dockerfile-cuda @@ -91,6 +91,8 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ rm -f $PROTOC_ZIP +COPY proto proto + RUN if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \ then \ cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F static-linking -F grpc --no-default-features && sccache -s; \ From 619ecbf3cc72dd6d1fc5caf24d342f558b21d940 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Sat, 25 Nov 2023 14:07:48 +0100 Subject: [PATCH 09/12] multiple services --- README.md | 20 +++++ proto/{text_embeddings.proto => tei.proto} | 20 +++-- router/build.rs | 2 +- router/src/grpc/mod.rs | 5 +- router/src/grpc/server.rs | 91 +++++++++++++++++++--- 5 files changed, 120 insertions(+), 18 deletions(-) rename proto/{text_embeddings.proto => tei.proto} (84%) diff --git a/README.md b/README.md index 66abc7b4..08f1dc7a 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ length of 512 tokens: - [Using Re-rankers models](#using-re-rankers-models) - [Using Sequence Classification models](#using-sequence-classification-models) - [Distributed Tracing](#distributed-tracing) + - [gRPC](#grpc) - [Local Install](#local-install) - [Docker Build](#docker-build) @@ -334,6 +335,25 @@ curl 127.0.0.1:8080/predict \ `text-embeddings-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature by setting the address to an OTLP collector with the `--otlp-endpoint` argument. +### gRPC + +`text-embeddings-inference` offers a gRPC API as an alternative to the default HTTP API for high performance +deployments. The API protobuf definition can be found [here](https://github.com/huggingface/text-embeddings-inference/blob/main/proto/tei.proto). + +You can use the gRPC API by adding the `+grpc` tag to any TEI Docker image. 
For example: + +```shell +model=BAAI/bge-large-en-v1.5 +revision=refs/pr/5 +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --gpus all -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:0.5+grpc --model-id $model --revision $revision +``` + +```shell +grpcurl -d '{"inputs": "What is Deep Learning"}' -plaintext 0.0.0.0:8080 tei.v1.Embed/Embed +``` + ## Local install ### CPU diff --git a/proto/text_embeddings.proto b/proto/tei.proto similarity index 84% rename from proto/text_embeddings.proto rename to proto/tei.proto index b9557bfd..e5852b8c 100644 --- a/proto/text_embeddings.proto +++ b/proto/tei.proto @@ -1,15 +1,23 @@ syntax = "proto3"; -package text_embeddings.v1; +package tei.v1; -service TextEmbeddings { +service Info { rpc Info (InfoRequest) returns (InfoResponse) { option idempotency_level = IDEMPOTENT; }; +} + +service Embed { + rpc Embed (EmbedRequest) returns (EmbedResponse); +} - rpc Embed(EmbedRequest) returns (EmbedResponse); - rpc Predict(PredictRequest) returns (PredictResponse); - rpc Rerank(RerankRequest) returns (RerankResponse); +service Predict { + rpc Predict (PredictRequest) returns (PredictResponse); +} + +service Rerank { + rpc Rerank (RerankRequest) returns (RerankResponse); } message InfoRequest {} @@ -77,4 +85,4 @@ message Rank { message RerankResponse { repeated Rank ranks = 1; -} \ No newline at end of file +} diff --git a/router/build.rs b/router/build.rs index cda55c66..b690eae0 100644 --- a/router/build.rs +++ b/router/build.rs @@ -37,7 +37,7 @@ fn main() -> Result<(), Box> { .file_descriptor_set_path(out_dir.join("descriptor.bin")) .out_dir("src/grpc/pb") .include_file("mod.rs") - .compile(&["../proto/text_embeddings.proto"], &["../proto"]) + .compile(&["../proto/tei.proto"], &["../proto"]) .unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e)); } diff --git a/router/src/grpc/mod.rs b/router/src/grpc/mod.rs index 31a2cacf..91d4c8d4 100644 --- a/router/src/grpc/mod.rs +++ b/router/src/grpc/mod.rs @@ -1,4 +1,7 @@ mod pb; pub(crate) mod server; -use pb::text_embeddings::v1::{text_embeddings_server::TextEmbeddingsServer, *}; +use pb::tei::v1::{ + embed_server::EmbedServer, info_server::InfoServer, predict_server::PredictServer, + rerank_server::RerankServer, *, +}; diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs index e319d49f..dc2b5995 100644 --- a/router/src/grpc/server.rs +++ b/router/src/grpc/server.rs @@ -10,18 +10,20 @@ use std::time::{Duration, Instant}; use text_embeddings_core::infer::Infer; use tonic::codegen::http::HeaderMap; use tonic::metadata::MetadataMap; +use tonic::server::NamedService; use tonic::transport::Server; use tonic::{Code, Extensions, Request, Response, Status}; +use tonic_health::ServingStatus; use tracing::instrument; -#[derive(Debug)] +#[derive(Debug, Clone)] struct TextEmbeddingsService { infer: Infer, info: Info, } #[tonic::async_trait] -impl grpc::text_embeddings_server::TextEmbeddings for TextEmbeddingsService { +impl grpc::info_server::Info for TextEmbeddingsService { async fn info(&self, _request: Request) -> Result, Status> { let model_type = match self.info.model_type { ModelType::Classifier(_) => grpc::ModelType::Classifier, @@ -45,6 +47,10 @@ impl grpc::text_embeddings_server::TextEmbeddings for TextEmbeddingsService { tokenization_workers: self.info.tokenization_workers as u32, })) } +} + +#[tonic::async_trait] +impl grpc::embed_server::Embed for TextEmbeddingsService { #[instrument( 
skip_all, fields(total_time, tokenization_time, queue_time, inference_time,) @@ -154,6 +160,10 @@ impl grpc::text_embeddings_server::TextEmbeddings for TextEmbeddingsService { Extensions::default(), )) } +} + +#[tonic::async_trait] +impl grpc::predict_server::Predict for TextEmbeddingsService { #[instrument( skip_all, fields(total_time, tokenization_time, queue_time, inference_time,) @@ -302,6 +312,10 @@ impl grpc::text_embeddings_server::TextEmbeddings for TextEmbeddingsService { Extensions::default(), )) } +} + +#[tonic::async_trait] +impl grpc::rerank_server::Rerank for TextEmbeddingsService { #[instrument( skip_all, fields(total_time, tokenization_time, queue_time, inference_time,) @@ -516,27 +530,80 @@ pub async fn run( // Liveness service let (mut health_reporter, health_service) = tonic_health::server::health_reporter(); + // Info is always serving + health_reporter + .set_serving::>() + .await; + // Set all other services to not serving + // Their health will be updated in the task below + health_reporter + .set_not_serving::>() + .await; + health_reporter + .set_not_serving::>() + .await; health_reporter - .set_serving::>() + .set_not_serving::>() .await; + // Backend health watcher let mut health_watcher = infer.health_watcher(); + // Clone model_type and move it to the task + let health_watcher_model_type = info.model_type.clone(); + + // Update services health tokio::spawn(async move { while health_watcher.changed().await.is_ok() { let health = *health_watcher.borrow_and_update(); - match health { - true => { + let status = match health { + true => ServingStatus::Serving, + false => ServingStatus::NotServing, + }; + + // Match on model type and set the health of the correct service(s) + // + // If Reranker, we have both a predict and rerank service + // + // This logic hints back to the user that if they try using the wrong service + // given the model type, it will always return an error. + // + // For example if the model type is `Embedding`, sending requests to `Rerank` will + // always return an `UNIMPLEMENTED` Status and both the `Rerank` and `Predict` services + // will have a `NOT_SERVING` ServingStatus. 
+ match health_watcher_model_type { + ModelType::Classifier(_) => { health_reporter - .set_serving::>() + .set_service_status( + >::NAME, + status, + ) .await } - false => { + ModelType::Embedding(_) => { health_reporter - .set_not_serving::>() + .set_service_status( + >::NAME, + status, + ) .await } - } + ModelType::Reranker(_) => { + // Reranker has both a predict and rerank service + health_reporter + .set_service_status( + >::NAME, + status, + ) + .await; + health_reporter + .set_service_status( + >::NAME, + status, + ) + .await; + } + }; } }); @@ -554,9 +621,13 @@ pub async fn run( Server::builder() .add_service(health_service) .add_service(reflection_service) - .add_service(grpc::TextEmbeddingsServer::new(service)) + .add_service(grpc::InfoServer::new(service.clone())) + .add_service(grpc::EmbedServer::new(service.clone())) + .add_service(grpc::PredictServer::new(service.clone())) + .add_service(grpc::RerankServer::new(service)) .serve_with_shutdown(addr, shutdown::shutdown_signal()) .await?; + Ok(()) } From 858bb4fc17fdc8b729df87f555b05453adb8eb4e Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Sat, 25 Nov 2023 14:50:30 +0100 Subject: [PATCH 10/12] prepare for streaming --- proto/tei.proto | 3 + router/src/grpc/server.rs | 713 ++++++++++++++++++-------------------- 2 files changed, 338 insertions(+), 378 deletions(-) diff --git a/proto/tei.proto b/proto/tei.proto index e5852b8c..01a6995d 100644 --- a/proto/tei.proto +++ b/proto/tei.proto @@ -10,14 +10,17 @@ service Info { service Embed { rpc Embed (EmbedRequest) returns (EmbedResponse); + //rpc EmbedStream (stream EmbedRequest) returns (stream EmbedResponse); } service Predict { rpc Predict (PredictRequest) returns (PredictResponse); + //rpc PredictStream (stream PredictRequest) returns (stream PredictResponse); } service Rerank { rpc Rerank (RerankRequest) returns (RerankResponse); + //rpc RerankStream (stream RerankRequest) returns (stream RerankResponse); } message InfoRequest {} diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs index dc2b5995..5fdbe4b5 100644 --- a/router/src/grpc/server.rs +++ b/router/src/grpc/server.rs @@ -7,328 +7,262 @@ use futures::future::join_all; use metrics_exporter_prometheus::PrometheusBuilder; use std::net::SocketAddr; use std::time::{Duration, Instant}; -use text_embeddings_core::infer::Infer; +use text_embeddings_core::infer::{Infer, InferResponse}; use tonic::codegen::http::HeaderMap; use tonic::metadata::MetadataMap; use tonic::server::NamedService; use tonic::transport::Server; use tonic::{Code, Extensions, Request, Response, Status}; use tonic_health::ServingStatus; -use tracing::instrument; - -#[derive(Debug, Clone)] -struct TextEmbeddingsService { - infer: Infer, - info: Info, +use tracing::{instrument, Span}; + +struct ResponseMetadata { + compute_chars: usize, + compute_tokens: usize, + start_time: Instant, + tokenization_time: Duration, + queue_time: Duration, + inference_time: Duration, } -#[tonic::async_trait] -impl grpc::info_server::Info for TextEmbeddingsService { - async fn info(&self, _request: Request) -> Result, Status> { - let model_type = match self.info.model_type { - ModelType::Classifier(_) => grpc::ModelType::Classifier, - ModelType::Embedding(_) => grpc::ModelType::Embedding, - ModelType::Reranker(_) => grpc::ModelType::Reranker, - }; - - Ok(Response::new(InfoResponse { - version: self.info.version.to_string(), - sha: self.info.sha.map(|s| s.to_string()), - docker_label: self.info.docker_label.map(|s| 
s.to_string()), - model_id: self.info.model_id.clone(), - model_sha: self.info.model_sha.clone(), - model_dtype: self.info.model_dtype.clone(), - model_type: model_type.into(), - max_concurrent_requests: self.info.max_concurrent_requests as u32, - max_input_length: self.info.max_input_length as u32, - max_batch_tokens: self.info.max_batch_tokens as u32, - max_batch_requests: self.info.max_batch_requests.map(|v| v as u32), - max_client_batch_size: self.info.max_client_batch_size as u32, - tokenization_workers: self.info.tokenization_workers as u32, - })) - } -} - -#[tonic::async_trait] -impl grpc::embed_server::Embed for TextEmbeddingsService { - #[instrument( - skip_all, - fields(total_time, tokenization_time, queue_time, inference_time,) - )] - async fn embed( - &self, - request: Request, - ) -> Result, Status> { - let span = tracing::Span::current(); - let start_time = Instant::now(); - - let request = request.into_inner(); - - let ( +impl ResponseMetadata { + fn new(compute_chars: usize, start_time: Instant, response: &InferResponse) -> Self { + Self { compute_chars, - compute_tokens, - tokenization_time, - queue_time, - inference_time, - response, - ) = { - metrics::increment_counter!("te_request_count", "method" => "single"); - - let compute_chars = request.inputs.chars().count(); - - let permit = self - .infer - .try_acquire_permit() - .map_err(ErrorResponse::from)?; - let response = self - .infer - .embed(request.inputs, request.truncate, request.normalize, permit) - .await - .map_err(ErrorResponse::from)?; - - metrics::increment_counter!("te_request_success", "method" => "single"); - - ( - compute_chars, - response.prompt_tokens, - response.tokenization, - response.queue, - response.inference, - EmbedResponse { - embeddings: response.results, - }, - ) - }; - - let total_time = start_time.elapsed(); + compute_tokens: response.prompt_tokens, + start_time, + tokenization_time: response.tokenization, + queue_time: response.queue, + inference_time: response.inference, + } + } + fn record_span(&self, span: &Span) { // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); + span.record("compute_chars", self.compute_chars); + span.record("compute_tokens", self.compute_tokens); + span.record("total_time", format!("{:?}", self.start_time.elapsed())); + span.record("tokenization_time", format!("{:?}", self.tokenization_time)); + span.record("queue_time", format!("{:?}", self.queue_time)); + span.record("inference_time", format!("{:?}", self.inference_time)); + } + + fn record_metrics(&self) { + // Metrics + metrics::histogram!( + "te_request_duration", + self.start_time.elapsed().as_secs_f64() + ); + metrics::histogram!( + "te_request_tokenization_duration", + self.tokenization_time.as_secs_f64() + ); + metrics::histogram!("te_request_queue_duration", self.queue_time.as_secs_f64()); + metrics::histogram!( + "te_request_inference_duration", + self.inference_time.as_secs_f64() + ); + } +} +impl From for HeaderMap { + fn from(value: ResponseMetadata) -> Self { // Headers let mut headers = HeaderMap::new(); headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); headers.insert( "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), + value + .start_time + .elapsed() + .as_millis() + .to_string() + .parse() + .unwrap(), ); headers.insert( "x-compute-characters", 
- compute_chars.to_string().parse().unwrap(), + value.compute_chars.to_string().parse().unwrap(), ); headers.insert( "x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), + value.compute_tokens.to_string().parse().unwrap(), ); headers.insert( "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), + value + .start_time + .elapsed() + .as_millis() + .to_string() + .parse() + .unwrap(), ); headers.insert( "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), + value + .tokenization_time + .as_millis() + .to_string() + .parse() + .unwrap(), ); headers.insert( "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), + value.queue_time.as_millis().to_string().parse().unwrap(), ); headers.insert( "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), + value + .inference_time + .as_millis() + .to_string() + .parse() + .unwrap(), ); - - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); - - tracing::info!("Success"); - - Ok(Response::from_parts( - MetadataMap::from_headers(headers), - response, - Extensions::default(), - )) + headers } } -#[tonic::async_trait] -impl grpc::predict_server::Predict for TextEmbeddingsService { +#[derive(Debug, Clone)] +struct TextEmbeddingsService { + infer: Infer, + info: Info, +} + +impl TextEmbeddingsService { #[instrument( skip_all, - fields(total_time, tokenization_time, queue_time, inference_time,) + fields( + compute_chars, + compute_tokens, + total_time, + tokenization_time, + queue_time, + inference_time, + ) )] - async fn predict( + async fn embed_inner( &self, - request: Request, - ) -> Result, Status> { + request: EmbedRequest, + ) -> Result<(EmbedResponse, ResponseMetadata), Status> { let span = tracing::Span::current(); let start_time = Instant::now(); - let request = request.into_inner(); + metrics::increment_counter!("te_request_count", "method" => "single"); - // Closure for predict - let predict_inner = move |inputs: String, - truncate: bool, - raw_scores: bool, - infer: Infer, - info: Info| async move { - let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; - let response = infer - .predict(inputs, truncate, raw_scores, permit) - .await - .map_err(ErrorResponse::from)?; + let compute_chars = request.inputs.chars().count(); - let id2label = match &info.model_type { - ModelType::Classifier(classifier) => &classifier.id2label, - ModelType::Reranker(classifier) => &classifier.id2label, - _ => panic!(), - }; + let permit = self + .infer + .try_acquire_permit() + .map_err(ErrorResponse::from)?; + let response = self + .infer + .embed(request.inputs, request.truncate, request.normalize, permit) + .await + .map_err(ErrorResponse::from)?; - let mut predictions: Vec = { - // Map score to label - response - .results - .into_iter() - .enumerate() - .map(|(i, s)| Prediction { - score: s, - label: id2label.get(&i.to_string()).unwrap().clone(), - }) - .collect() - }; - // Reverse sort - predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); - predictions.reverse(); + metrics::increment_counter!("te_request_success", "method" => "single"); - Ok::<(usize, Duration, Duration, Duration, Vec), ErrorResponse>(( - response.prompt_tokens, - 
response.tokenization, - response.queue, - response.inference, - predictions, - )) - }; + let response_metadata = ResponseMetadata::new(compute_chars, start_time, &response); + response_metadata.record_span(&span); + response_metadata.record_metrics(); + + tracing::info!("Success"); - let ( + Ok(( + EmbedResponse { + embeddings: response.results, + }, + response_metadata, + )) + } + + #[instrument( + skip_all, + fields( compute_chars, compute_tokens, + total_time, tokenization_time, queue_time, inference_time, - predictions, - ) = { - metrics::increment_counter!("te_request_count", "method" => "single"); + ) + )] + async fn predict_inner( + &self, + request: PredictRequest, + ) -> Result<(PredictResponse, ResponseMetadata), Status> { + let span = tracing::Span::current(); + let start_time = Instant::now(); - let compute_chars = request.inputs.chars().count(); - let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner( - request.inputs, - request.truncate, - request.raw_scores, - self.infer.clone(), - self.info.clone(), - ) - .await?; - - metrics::increment_counter!("te_request_success", "method" => "single"); - - ( - compute_chars, - prompt_tokens, - tokenization, - queue, - inference, - predictions, - ) + metrics::increment_counter!("te_request_count", "method" => "single"); + + let compute_chars = request.inputs.chars().count(); + + let permit = self + .infer + .try_acquire_permit() + .map_err(ErrorResponse::from)?; + let response = self + .infer + .predict(request.inputs, request.truncate, request.raw_scores, permit) + .await + .map_err(ErrorResponse::from)?; + + let id2label = match &self.info.model_type { + ModelType::Classifier(classifier) => &classifier.id2label, + ModelType::Reranker(classifier) => &classifier.id2label, + _ => panic!(), }; - let total_time = start_time.elapsed(); + let response_metadata = ResponseMetadata::new(compute_chars, start_time, &response); - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); + let mut predictions: Vec = { + // Map score to label + response + .results + .into_iter() + .enumerate() + .map(|(i, s)| Prediction { + score: s, + label: id2label.get(&i.to_string()).unwrap().clone(), + }) + .collect() + }; + // Reverse sort + predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + predictions.reverse(); - // Headers - let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - compute_chars.to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), - ); + metrics::increment_counter!("te_request_success", "method" => "single"); - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - 
"te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); + response_metadata.record_span(&span); + response_metadata.record_metrics(); tracing::info!("Success"); - Ok(Response::from_parts( - MetadataMap::from_headers(headers), - PredictResponse { predictions }, - Extensions::default(), - )) + Ok((PredictResponse { predictions }, response_metadata)) } -} -#[tonic::async_trait] -impl grpc::rerank_server::Rerank for TextEmbeddingsService { #[instrument( skip_all, - fields(total_time, tokenization_time, queue_time, inference_time,) + fields( + compute_chars, + compute_tokens, + total_time, + tokenization_time, + queue_time, + inference_time, + ) )] - async fn rerank( + async fn rerank_inner( &self, - request: Request, - ) -> Result, Status> { + request: RerankRequest, + ) -> Result<(RerankResponse, ResponseMetadata), Status> { let span = tracing::Span::current(); let start_time = Instant::now(); - let request = request.into_inner(); - match &self.info.model_type { ModelType::Classifier(_) => { metrics::increment_counter!("te_request_failure", "err" => "model_type"); @@ -369,147 +303,170 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { )) }; - let ( - compute_chars, - compute_tokens, - tokenization_time, - queue_time, - inference_time, - response, - ) = { - metrics::increment_counter!("te_request_count", "method" => "batch"); - - let batch_size = request.texts.len(); - if batch_size > self.info.max_client_batch_size { - let message = format!( - "batch size {batch_size} > maximum allowed batch size {}", - self.info.max_client_batch_size - ); - tracing::error!("{message}"); - let err = ErrorResponse { - error: message, - error_type: ErrorType::Validation, - }; - metrics::increment_counter!("te_request_failure", "err" => "batch_size"); - Err(err)?; - } + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let batch_size = request.texts.len(); + if batch_size > self.info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + self.info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; + } - let mut futures = Vec::with_capacity(batch_size); - let query_chars = request.query.chars().count(); - let mut compute_chars = query_chars * batch_size; - - for text in &request.texts { - compute_chars += text.chars().count(); - let local_infer = self.infer.clone(); - futures.push(rerank_inner( - request.query.clone(), - text.clone(), - request.truncate, - request.raw_scores, - local_infer, - )) - } - let results = join_all(futures) - .await - .into_iter() - .collect::, ErrorResponse>>( - )?; - - let mut ranks = Vec::with_capacity(batch_size); - let mut total_tokenization_time = 0; - let mut total_queue_time = 0; - let mut total_inference_time = 0; - let mut total_compute_tokens = 0; - - for (index, r) in results.into_iter().enumerate() { - total_compute_tokens += r.0; - total_tokenization_time += r.1.as_nanos() as u64; - total_queue_time += r.2.as_nanos() as u64; - total_inference_time += r.3.as_nanos() as u64; - let text = if request.return_text { - Some(request.texts[index].clone()) - } else { - None - }; - - ranks.push(Rank { - index: index as u32, 
- text, - score: r.4, - }) - } + let mut futures = Vec::with_capacity(batch_size); + let query_chars = request.query.chars().count(); + let mut total_compute_chars = query_chars * batch_size; + + for text in &request.texts { + total_compute_chars += text.chars().count(); + let local_infer = self.infer.clone(); + futures.push(rerank_inner( + request.query.clone(), + text.clone(), + request.truncate, + request.raw_scores, + local_infer, + )) + } + let results = join_all(futures) + .await + .into_iter() + .collect::, ErrorResponse>>()?; + + let mut ranks = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + for (index, r) in results.into_iter().enumerate() { + total_compute_tokens += r.0; + total_tokenization_time += r.1.as_nanos() as u64; + total_queue_time += r.2.as_nanos() as u64; + total_inference_time += r.3.as_nanos() as u64; + let text = if request.return_text { + Some(request.texts[index].clone()) + } else { + None + }; - // Reverse sort - ranks.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); - ranks.reverse(); + ranks.push(Rank { + index: index as u32, + text, + score: r.4, + }) + } - let batch_size = batch_size as u64; + // Reverse sort + ranks.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + ranks.reverse(); - metrics::increment_counter!("te_request_success", "method" => "batch"); + let batch_size = batch_size as u64; - ( - compute_chars, - total_compute_tokens, - Duration::from_nanos(total_tokenization_time / batch_size), - Duration::from_nanos(total_queue_time / batch_size), - Duration::from_nanos(total_inference_time / batch_size), - RerankResponse { ranks }, - ) + metrics::increment_counter!("te_request_success", "method" => "batch"); + + let response_metadata = ResponseMetadata { + compute_chars: total_compute_chars, + compute_tokens: total_compute_tokens, + start_time, + tokenization_time: Duration::from_nanos(total_tokenization_time / batch_size), + queue_time: Duration::from_nanos(total_queue_time / batch_size), + inference_time: Duration::from_nanos(total_inference_time / batch_size), }; + response_metadata.record_span(&span); + response_metadata.record_metrics(); - let total_time = start_time.elapsed(); + tracing::info!("Success"); - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); + Ok((RerankResponse { ranks }, response_metadata)) + } +} - // Headers - let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - compute_chars.to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), - ); +#[tonic::async_trait] +impl grpc::info_server::Info for TextEmbeddingsService { + 
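+    // Note: `info` is a unary, read-only RPC. It maps the router's internal
+    // `ModelType` onto the generated gRPC enum and reports the static model and
+    // server configuration (request limits, batch limits, tokenization workers).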
async fn info(&self, _request: Request) -> Result, Status> { + let model_type = match self.info.model_type { + ModelType::Classifier(_) => grpc::ModelType::Classifier, + ModelType::Embedding(_) => grpc::ModelType::Embedding, + ModelType::Reranker(_) => grpc::ModelType::Reranker, + }; - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); + Ok(Response::new(InfoResponse { + version: self.info.version.to_string(), + sha: self.info.sha.map(|s| s.to_string()), + docker_label: self.info.docker_label.map(|s| s.to_string()), + model_id: self.info.model_id.clone(), + model_sha: self.info.model_sha.clone(), + model_dtype: self.info.model_dtype.clone(), + model_type: model_type.into(), + max_concurrent_requests: self.info.max_concurrent_requests as u32, + max_input_length: self.info.max_input_length as u32, + max_batch_tokens: self.info.max_batch_tokens as u32, + max_batch_requests: self.info.max_batch_requests.map(|v| v as u32), + max_client_batch_size: self.info.max_client_batch_size as u32, + tokenization_workers: self.info.tokenization_workers as u32, + })) + } +} - tracing::info!("Success"); +#[tonic::async_trait] +impl grpc::embed_server::Embed for TextEmbeddingsService { + async fn embed( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + let (response, metadata) = self.embed_inner(request).await?; + + let headers = HeaderMap::from(metadata); + + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + response, + Extensions::default(), + )) + } +} + +#[tonic::async_trait] +impl grpc::predict_server::Predict for TextEmbeddingsService { + async fn predict( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + let (response, metadata) = self.predict_inner(request).await?; + + let headers = HeaderMap::from(metadata); + + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + response, + Extensions::default(), + )) + } +} + +#[tonic::async_trait] +impl grpc::rerank_server::Rerank for TextEmbeddingsService { + async fn rerank( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + let (response, metadata) = self.rerank_inner(request).await?; + + let headers = HeaderMap::from(metadata); Ok(Response::from_parts( MetadataMap::from_headers(headers), From b8e8f202e29e4c1ac686f523e7e3de5ebc6ac6f8 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 27 Nov 2023 13:00:35 +0100 Subject: [PATCH 11/12] stream --- Cargo.lock | 3 +- load_tests/load_grpc.js | 14 +- load_tests/load_grpc_stream.js | 62 ++++ proto/tei.proto | 28 +- router/Cargo.toml | 5 +- router/src/grpc/server.rs | 552 +++++++++++++++++++++++++++------ router/src/http/server.rs | 25 +- 7 files changed, 576 insertions(+), 113 deletions(-) create mode 100644 load_tests/load_grpc_stream.js diff --git a/Cargo.lock b/Cargo.lock index 108e1f57..10ef7cef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3190,6 +3190,7 @@ name = "text-embeddings-router" version = "0.5.0" dependencies = [ "anyhow", + "async-stream", "axum", "axum-tracing-opentelemetry", "clap", @@ -3202,7 +3203,6 @@ dependencies = [ "opentelemetry 0.20.0", "opentelemetry-otlp", "prost 0.12.3", - 
"prost-types 0.12.3", "reqwest", "serde", "serde_json", @@ -3211,6 +3211,7 @@ dependencies = [ "thiserror", "tokenizers", "tokio", + "tokio-stream", "tonic 0.10.2", "tonic-build 0.10.2", "tonic-health", diff --git a/load_tests/load_grpc.js b/load_tests/load_grpc.js index 80abbd96..fcd9b5b5 100644 --- a/load_tests/load_grpc.js +++ b/load_tests/load_grpc.js @@ -1,5 +1,5 @@ import {check} from 'k6'; -import grpc from 'k6/net/grpc'; +import grpc from 'k6/experimental/grpc'; import {Trend} from 'k6/metrics'; const host = __ENV.HOST || '127.0.0.1:3000'; @@ -37,19 +37,21 @@ export const options = { const client = new grpc.Client(); -client.load(['definitions'], '../../proto/text_embeddings.proto'); +client.load([], '../proto/tei.proto'); export default function () { - client.connect(host, { - plaintext: true - }); + if (__ITER == 0) { + client.connect(host, { + plaintext: true + }); + } const payload = { inputs: inputs, truncate: true, }; - const res = client.invoke('text_embeddings.v1.TextEmbeddings/Embed', payload); + const res = client.invoke('tei.v1.Embed/Embed', payload); check(res, { 'status is OK': (r) => r && r.status === grpc.StatusOK, diff --git a/load_tests/load_grpc_stream.js b/load_tests/load_grpc_stream.js new file mode 100644 index 00000000..89b9115c --- /dev/null +++ b/load_tests/load_grpc_stream.js @@ -0,0 +1,62 @@ +import grpc from 'k6/experimental/grpc'; +import {Trend} from 'k6/metrics'; + +const host = __ENV.HOST || '127.0.0.1:3000'; + +const totalTime = new Trend('total_time', true); +const tokenizationTIme = new Trend('tokenization_time', true); +const queueTime = new Trend('queue_time', true); +const inferenceTime = new Trend('inference_time', true); + +export const inputs = 'A path from a point approximately 330 metres east of the most south westerleasterly corner of Unit 4 Foundry Industrial Estate, then proceeding in a generally east-north-east direction for approximately 64 metres to a point approximately 282 metres east-south-east of the most easterly corner of Unit 2 Foundry Industrial Estate, Victoria Street, Widnes and approximately 259 metres east of the most southerly corner of Unit 4 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-north-east direction for approximately 350 metres to a point approximately 3 metres west-north-west of the most north westerly corner of the boundary fence of the scrap metal yard on the south side of Cornubia Road, Widnes, and approximately 47 metres west-south-west of the stub end of Cornubia Road be diverted to a 3 metre wide path from a point approximately 183 metres east-south-east of the most easterly corner of Unit 5 Foundry Industrial Estate, Victoria Street and approximately 272 metres east of the most north-easterly corner of 26 Ann Street West, Widnes, then proceeding in a generally north easterly direction for approximately 58 metres to a point approximately 216 metres east-south-east of the most easterly corner of Unit 4 Foundry Industrial Estate, Victoria Street and approximately 221 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally easterly direction for approximately 45 metres to a point approximately 265 metres east-south-east of the most north-easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 265 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-south-east direction for approximately 102 metres 
to a point approximately 366 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 463 metres east of the most north easterly corner of 22 Ann Street West, Widnes, then proceeding in a generally north-north-easterly direction for approximately 19 metres to a point approximately 368 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 512 metres east of the most south easterly corner of 17 Batherton Close, Widnes then proceeding in a generally east-south, easterly direction for approximately 16 metres to a point approximately 420 metres east-south-east of the most southerly corner of Unit 2 Foundry'; + +export const options = { + scenarios: { + throughput: { + executor: 'shared-iterations', + vus: 1, + iterations: 1, + maxDuration: '2m', + gracefulStop: '1s', + }, + }, +}; + + +const client = new grpc.Client(); + +client.load([], '../proto/tei.proto'); + +export default function () { + if (__ITER == 0) { + client.connect(host, { + plaintext: true + }); + } + + const stream = new grpc.Stream(client, 'tei.v1.Embed/EmbedStream'); + + stream.on('data', (res) => { + totalTime.add(res.metadata.totalTimeNs / 1e6); + tokenizationTIme.add(res.metadata.tokenizationTimeNs / 1e6); + queueTime.add(res.metadata.queueTimeNs / 1e6); + inferenceTime.add(res.metadata.inferenceTimeNs / 1e6); + }); + + stream.on('error', (err) => { + console.log('Stream Error: ' + JSON.stringify(err)); + }); + + const payload = { + inputs: inputs, + truncate: true, + }; + + // send 128 requests + for (let i = 0; i < 64000; i++) { + stream.write(payload); + } + + // close the client stream + stream.end(); +} diff --git a/proto/tei.proto b/proto/tei.proto index 01a6995d..d9131679 100644 --- a/proto/tei.proto +++ b/proto/tei.proto @@ -10,17 +10,17 @@ service Info { service Embed { rpc Embed (EmbedRequest) returns (EmbedResponse); - //rpc EmbedStream (stream EmbedRequest) returns (stream EmbedResponse); + rpc EmbedStream (stream EmbedRequest) returns (stream EmbedResponse); } service Predict { rpc Predict (PredictRequest) returns (PredictResponse); - //rpc PredictStream (stream PredictRequest) returns (stream PredictResponse); + rpc PredictStream (stream PredictRequest) returns (stream PredictResponse); } service Rerank { rpc Rerank (RerankRequest) returns (RerankResponse); - //rpc RerankStream (stream RerankRequest) returns (stream RerankResponse); + rpc RerankStream (stream RerankStreamRequest) returns (RerankResponse); } message InfoRequest {} @@ -47,6 +47,15 @@ message InfoResponse { uint32 tokenization_workers = 13; } +message Metadata { + uint32 compute_chars = 1; + uint32 compute_tokens = 2; + uint64 total_time_ns = 3; + uint64 tokenization_time_ns = 4; + uint64 queue_time_ns = 5; + uint64 inference_time_ns = 6; +} + message EmbedRequest { string inputs = 1; bool truncate = 2; @@ -55,6 +64,7 @@ message EmbedRequest { message EmbedResponse { repeated float embeddings = 1; + Metadata metadata = 2; } message PredictRequest { @@ -70,6 +80,7 @@ message Prediction { message PredictResponse { repeated Prediction predictions = 1; + Metadata metadata = 2; } message RerankRequest { @@ -80,6 +91,16 @@ message RerankRequest { bool return_text = 5; } +message RerankStreamRequest{ + string query = 1; + string text = 2; + bool truncate = 3; + // The server will only consider the first value + bool raw_scores = 4; + // The server will only consider the first value + bool return_text = 5; +} + message Rank 
{ uint32 index = 1; optional string text = 2; @@ -88,4 +109,5 @@ message Rank { message RerankResponse { repeated Rank ranks = 1; + Metadata metadata = 2; } diff --git a/router/Cargo.toml b/router/Cargo.toml index 1716e248..b9e4ecd1 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -46,11 +46,12 @@ utoipa = { version = "4.0.0", features = ["axum_extras"], optional = true } utoipa-swagger-ui = { version = "4.0.0", features = ["axum"], optional = true } # gRPC dependencies +async-stream = { version = "0.3.5", optional = true } prost = { version = "0.12.1", optional = true } -prost-types = { version = "0.12.1", optional = true } tonic = { version = "0.10.2", optional = true } tonic-health = { version = "0.10.2", optional = true } tonic-reflection = { version = "0.10.2", optional = true } +tokio-stream = { version = "0.1.14", optional = true } [build-dependencies] vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } @@ -59,7 +60,7 @@ tonic-build = { version = "0.10.2", optional = true } [features] default = ["candle", "http"] http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"] -grpc = ["metrics-exporter-prometheus/http-listener", "dep:prost", "dep:prost-types", "dep:tonic", "dep:tonic-health", "dep:tonic-reflection", "dep:tonic-build"] +grpc = ["metrics-exporter-prometheus/http-listener", "dep:prost", "dep:tonic", "dep:tonic-health", "dep:tonic-reflection", "dep:tonic-build", "dep:async-stream", "dep:tokio-stream"] mkl = ["text-embeddings-backend/mkl"] mkl-dynamic = ["text-embeddings-backend/mkl-dynamic"] accelerate = ["text-embeddings-backend/accelerate"] diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs index 5fdbe4b5..dbc40155 100644 --- a/router/src/grpc/server.rs +++ b/router/src/grpc/server.rs @@ -1,3 +1,4 @@ +use crate::grpc::pb::tei::v1::RerankStreamRequest; use crate::grpc::{ EmbedRequest, EmbedResponse, InfoRequest, InfoResponse, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, @@ -8,11 +9,14 @@ use metrics_exporter_prometheus::PrometheusBuilder; use std::net::SocketAddr; use std::time::{Duration, Instant}; use text_embeddings_core::infer::{Infer, InferResponse}; +use tokio::sync::{mpsc, OwnedSemaphorePermit}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::StreamExt; use tonic::codegen::http::HeaderMap; use tonic::metadata::MetadataMap; use tonic::server::NamedService; use tonic::transport::Server; -use tonic::{Code, Extensions, Request, Response, Status}; +use tonic::{Code, Extensions, Request, Response, Status, Streaming}; use tonic_health::ServingStatus; use tracing::{instrument, Span}; @@ -124,13 +128,39 @@ impl From for HeaderMap { } } +impl From<&ResponseMetadata> for grpc::Metadata { + fn from(value: &ResponseMetadata) -> Self { + Self { + compute_chars: value.compute_chars as u32, + compute_tokens: value.compute_tokens as u32, + total_time_ns: value.start_time.elapsed().as_nanos() as u64, + tokenization_time_ns: value.tokenization_time.as_nanos() as u64, + queue_time_ns: value.queue_time.as_nanos() as u64, + inference_time_ns: value.inference_time.as_nanos() as u64, + } + } +} + #[derive(Debug, Clone)] struct TextEmbeddingsService { infer: Infer, info: Info, + max_parallel_stream_requests: usize, } impl TextEmbeddingsService { + fn new(infer: Infer, info: Info) -> Self { + let max_parallel_stream_requests = std::env::var("GRPC_MAX_PARALLEL_STREAM_REQUESTS") + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(1024); + 
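+        // This value bounds the per-stream request backlog: `embed_stream`,
+        // `predict_stream` and `rerank_stream` below create their bounded mpsc
+        // channels with this capacity, so at most this many requests from one
+        // client stream are queued while waiting for an inference permit.
+        // Overridable via `GRPC_MAX_PARALLEL_STREAM_REQUESTS` (default 1024).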
Self { + infer, + info, + max_parallel_stream_requests, + } + } + #[instrument( skip_all, fields( @@ -145,26 +175,18 @@ impl TextEmbeddingsService { async fn embed_inner( &self, request: EmbedRequest, + permit: OwnedSemaphorePermit, ) -> Result<(EmbedResponse, ResponseMetadata), Status> { - let span = tracing::Span::current(); + let span = Span::current(); let start_time = Instant::now(); - metrics::increment_counter!("te_request_count", "method" => "single"); - let compute_chars = request.inputs.chars().count(); - - let permit = self - .infer - .try_acquire_permit() - .map_err(ErrorResponse::from)?; let response = self .infer .embed(request.inputs, request.truncate, request.normalize, permit) .await .map_err(ErrorResponse::from)?; - metrics::increment_counter!("te_request_success", "method" => "single"); - let response_metadata = ResponseMetadata::new(compute_chars, start_time, &response); response_metadata.record_span(&span); response_metadata.record_metrics(); @@ -174,6 +196,7 @@ impl TextEmbeddingsService { Ok(( EmbedResponse { embeddings: response.results, + metadata: Some(grpc::Metadata::from(&response_metadata)), }, response_metadata, )) @@ -193,18 +216,12 @@ impl TextEmbeddingsService { async fn predict_inner( &self, request: PredictRequest, + permit: OwnedSemaphorePermit, ) -> Result<(PredictResponse, ResponseMetadata), Status> { - let span = tracing::Span::current(); + let span = Span::current(); let start_time = Instant::now(); - metrics::increment_counter!("te_request_count", "method" => "single"); - let compute_chars = request.inputs.chars().count(); - - let permit = self - .infer - .try_acquire_permit() - .map_err(ErrorResponse::from)?; let response = self .infer .predict(request.inputs, request.truncate, request.raw_scores, permit) @@ -235,16 +252,221 @@ impl TextEmbeddingsService { predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); predictions.reverse(); - metrics::increment_counter!("te_request_success", "method" => "single"); - response_metadata.record_span(&span); response_metadata.record_metrics(); tracing::info!("Success"); - Ok((PredictResponse { predictions }, response_metadata)) + Ok(( + PredictResponse { + predictions, + metadata: Some(grpc::Metadata::from(&response_metadata)), + }, + response_metadata, + )) } +} + +#[tonic::async_trait] +impl grpc::info_server::Info for TextEmbeddingsService { + async fn info(&self, _request: Request) -> Result, Status> { + let model_type = match self.info.model_type { + ModelType::Classifier(_) => grpc::ModelType::Classifier, + ModelType::Embedding(_) => grpc::ModelType::Embedding, + ModelType::Reranker(_) => grpc::ModelType::Reranker, + }; + + Ok(Response::new(InfoResponse { + version: self.info.version.to_string(), + sha: self.info.sha.map(|s| s.to_string()), + docker_label: self.info.docker_label.map(|s| s.to_string()), + model_id: self.info.model_id.clone(), + model_sha: self.info.model_sha.clone(), + model_dtype: self.info.model_dtype.clone(), + model_type: model_type.into(), + max_concurrent_requests: self.info.max_concurrent_requests as u32, + max_input_length: self.info.max_input_length as u32, + max_batch_tokens: self.info.max_batch_tokens as u32, + max_batch_requests: self.info.max_batch_requests.map(|v| v as u32), + max_client_batch_size: self.info.max_client_batch_size as u32, + tokenization_workers: self.info.tokenization_workers as u32, + })) + } +} + +#[tonic::async_trait] +impl grpc::embed_server::Embed for TextEmbeddingsService { + #[instrument(skip_all)] + async fn embed( + &self, + request: 
Request, + ) -> Result, Status> { + metrics::increment_counter!("te_request_count", "method" => "single"); + + let permit = self + .infer + .try_acquire_permit() + .map_err(ErrorResponse::from)?; + + let request = request.into_inner(); + let (response, metadata) = self.embed_inner(request, permit).await?; + let headers = HeaderMap::from(metadata); + + metrics::increment_counter!("te_request_success", "method" => "single"); + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + response, + Extensions::default(), + )) + } + + type EmbedStreamStream = UnboundedReceiverStream>; + + #[instrument(skip_all)] + async fn embed_stream( + &self, + request: Request>, + ) -> Result, Status> { + let mut request_stream = request.into_inner(); + + // Create bounded channel to have an upper bound of spawned tasks + // We will have at most `max_parallel_stream_requests` messages from this stream in the queue + let (embed_sender, mut embed_receiver) = mpsc::channel(self.max_parallel_stream_requests); + + // Final channel for the outputs + let (response_sender, response_receiver) = mpsc::unbounded_channel(); + + // Required for the async move below + let local = self.clone(); + + // Background task that uses the bounded channel + tokio::spawn(async move { + while let Some(request) = embed_receiver.recv().await { + // Wait on permit before spawning the task to avoid creating more tasks than needed + let permit = local.infer.acquire_permit().await; + + // Required for the async move below + let task_local = local.clone(); + let task_response_sender = response_sender.clone(); + + // Create async task for this specific input + tokio::spawn(async move { + // Select on closed to cancel work if the stream was closed + tokio::select! { + response = task_local.embed_inner(request, permit) => { + let _ = task_response_sender.send(response.map(|(r, _m)| r)); + } + _ = task_response_sender.closed() => {} + } + }); + } + }); + + // Iterate on input + while let Some(request) = request_stream.next().await { + embed_sender + .send(request?) + .await + .expect("`embed_receiver` was dropped. 
This is a bug."); + } + // Drop the sender + drop(embed_sender); + + Ok(Response::new(UnboundedReceiverStream::new( + response_receiver, + ))) + } +} + +#[tonic::async_trait] +impl grpc::predict_server::Predict for TextEmbeddingsService { + #[instrument(skip_all)] + async fn predict( + &self, + request: Request, + ) -> Result, Status> { + metrics::increment_counter!("te_request_count", "method" => "single"); + + let permit = self + .infer + .try_acquire_permit() + .map_err(ErrorResponse::from)?; + + let request = request.into_inner(); + let (response, metadata) = self.predict_inner(request, permit).await?; + let headers = HeaderMap::from(metadata); + + metrics::increment_counter!("te_request_success", "method" => "single"); + + Ok(Response::from_parts( + MetadataMap::from_headers(headers), + response, + Extensions::default(), + )) + } + + type PredictStreamStream = UnboundedReceiverStream>; + + #[instrument(skip_all)] + async fn predict_stream( + &self, + request: Request>, + ) -> Result, Status> { + let mut request_stream = request.into_inner(); + + // Create bounded channel to have an upper bound of spawned tasks + // We will have at most `max_parallel_stream_requests` messages from this stream in the queue + let (predict_sender, mut predict_receiver) = + mpsc::channel(self.max_parallel_stream_requests); + + // Final channel for the outputs + let (response_sender, response_receiver) = mpsc::unbounded_channel(); + + // Required for the async move below + let local = self.clone(); + + // Background task that uses the bounded channel + tokio::spawn(async move { + while let Some(request) = predict_receiver.recv().await { + // Wait on permit before spawning the task to avoid creating more tasks than needed + let permit = local.infer.acquire_permit().await; + + // Required for the async move below + let task_local = local.clone(); + let task_response_sender = response_sender.clone(); + + // Create async task for this specific input + tokio::spawn(async move { + // Select on closed to cancel work if the stream was closed + tokio::select! { + response = task_local.predict_inner(request, permit) => { + let _ = task_response_sender.send(response.map(|(r, _m)| r)); + } + _ = task_response_sender.closed() => {} + } + }); + } + }); + + // Iterate on input + while let Some(request) = request_stream.next().await { + predict_sender + .send(request?) + .await + .expect("`predict_receiver` was dropped. 
This is a bug."); + } + // Drop the sender + drop(predict_sender); + + Ok(Response::new(UnboundedReceiverStream::new( + response_receiver, + ))) + } +} + +#[tonic::async_trait] +impl grpc::rerank_server::Rerank for TextEmbeddingsService { #[instrument( skip_all, fields( @@ -256,13 +478,15 @@ impl TextEmbeddingsService { inference_time, ) )] - async fn rerank_inner( + async fn rerank( &self, - request: RerankRequest, - ) -> Result<(RerankResponse, ResponseMetadata), Status> { - let span = tracing::Span::current(); + request: Request, + ) -> Result, Status> { + let span = Span::current(); let start_time = Instant::now(); + let request = request.into_inner(); + match &self.info.model_type { ModelType::Classifier(_) => { metrics::increment_counter!("te_request_failure", "err" => "model_type"); @@ -285,7 +509,7 @@ impl TextEmbeddingsService { truncate: bool, raw_scores: bool, infer: Infer| async move { - let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + let permit = infer.acquire_permit().await; let response = infer .predict((query, text), truncate, raw_scores, permit) @@ -383,94 +607,230 @@ impl TextEmbeddingsService { response_metadata.record_span(&span); response_metadata.record_metrics(); - tracing::info!("Success"); - - Ok((RerankResponse { ranks }, response_metadata)) - } -} - -#[tonic::async_trait] -impl grpc::info_server::Info for TextEmbeddingsService { - async fn info(&self, _request: Request) -> Result, Status> { - let model_type = match self.info.model_type { - ModelType::Classifier(_) => grpc::ModelType::Classifier, - ModelType::Embedding(_) => grpc::ModelType::Embedding, - ModelType::Reranker(_) => grpc::ModelType::Reranker, + let message = RerankResponse { + ranks, + metadata: Some(grpc::Metadata::from(&response_metadata)), }; - Ok(Response::new(InfoResponse { - version: self.info.version.to_string(), - sha: self.info.sha.map(|s| s.to_string()), - docker_label: self.info.docker_label.map(|s| s.to_string()), - model_id: self.info.model_id.clone(), - model_sha: self.info.model_sha.clone(), - model_dtype: self.info.model_dtype.clone(), - model_type: model_type.into(), - max_concurrent_requests: self.info.max_concurrent_requests as u32, - max_input_length: self.info.max_input_length as u32, - max_batch_tokens: self.info.max_batch_tokens as u32, - max_batch_requests: self.info.max_batch_requests.map(|v| v as u32), - max_client_batch_size: self.info.max_client_batch_size as u32, - tokenization_workers: self.info.tokenization_workers as u32, - })) - } -} - -#[tonic::async_trait] -impl grpc::embed_server::Embed for TextEmbeddingsService { - async fn embed( - &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); - - let (response, metadata) = self.embed_inner(request).await?; + let headers = HeaderMap::from(response_metadata); - let headers = HeaderMap::from(metadata); + tracing::info!("Success"); Ok(Response::from_parts( MetadataMap::from_headers(headers), - response, + message, Extensions::default(), )) } -} -#[tonic::async_trait] -impl grpc::predict_server::Predict for TextEmbeddingsService { - async fn predict( + #[instrument( + skip_all, + fields( + compute_chars, + compute_tokens, + total_time, + tokenization_time, + queue_time, + inference_time, + ) + )] + async fn rerank_stream( &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); + request: Request>, + ) -> Result, Status> { + let span = Span::current(); + let start_time = Instant::now(); - let (response, metadata) = 
self.predict_inner(request).await?; + // Check model type + match &self.info.model_type { + ModelType::Classifier(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a re-ranker model".to_string(); + tracing::error!("{message}"); + Err(Status::new(Code::FailedPrecondition, message)) + } + ModelType::Reranker(_) => Ok(()), + ModelType::Embedding(_) => { + metrics::increment_counter!("te_request_failure", "err" => "model_type"); + let message = "model is not a classifier model".to_string(); + tracing::error!("{message}"); + Err(Status::new(Code::FailedPrecondition, message)) + } + }?; - let headers = HeaderMap::from(metadata); + // Closure for rerank + let rerank_inner = move |index: usize, + query: String, + text: String, + truncate: bool, + raw_scores: bool, + infer: Infer, + permit: OwnedSemaphorePermit| async move { + let response = infer + .predict((query, text.clone()), truncate, raw_scores, permit) + .await + .map_err(ErrorResponse::from)?; - Ok(Response::from_parts( - MetadataMap::from_headers(headers), - response, - Extensions::default(), - )) - } -} + let score = response.results[0]; -#[tonic::async_trait] -impl grpc::rerank_server::Rerank for TextEmbeddingsService { - async fn rerank( - &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); + Ok::<(usize, usize, Duration, Duration, Duration, f32, String), ErrorResponse>(( + index, + response.prompt_tokens, + response.tokenization, + response.queue, + response.inference, + score, + text, + )) + }; - let (response, metadata) = self.rerank_inner(request).await?; + metrics::increment_counter!("te_request_count", "method" => "batch"); - let headers = HeaderMap::from(metadata); + let mut request_stream = request.into_inner(); + + // Create bounded channel to have an upper bound of spawned tasks + // We will have at most `max_parallel_stream_requests` messages from this stream in the queue + let (rerank_sender, mut rerank_receiver) = mpsc::channel(self.max_parallel_stream_requests); + + // Final channel for the outputs + let (response_sender, mut response_receiver) = mpsc::unbounded_channel(); + + // Required for the async move below + let local_infer = self.infer.clone(); + + // Background task that uses the bounded channel + tokio::spawn(async move { + while let Some((index, query, text, truncate, raw_scores)) = + rerank_receiver.recv().await + { + // Wait on permit before spawning the task to avoid creating more tasks than needed + let permit = local_infer.acquire_permit().await; + + // Required for the async move below + let task_response_sender = response_sender.clone(); + let task_infer = local_infer.clone(); + + // Create async task for this specific input + tokio::spawn(async move { + // Select on closed to cancel work if the stream was closed + tokio::select! 
{ + result = rerank_inner(index, query, text, truncate, raw_scores, task_infer, permit) => { + let _ = task_response_sender.send(result); + } + _ = task_response_sender.closed() => {} + } + }); + } + }); + + let mut index = 0; + let mut total_compute_chars = 0; + + // Set by first request + let mut raw_scores = None; + let mut return_text = None; + + while let Some(request) = request_stream.next().await { + let request = request?; + + // Set `raw_scores` and `return_text` using the values in the first request + if raw_scores.is_none() && return_text.is_none() { + raw_scores = Some(request.raw_scores); + return_text = Some(request.return_text); + } + + total_compute_chars += request.query.chars().count(); + total_compute_chars += request.text.chars().count(); + + rerank_sender + .send(( + index, + request.query, + request.text, + request.truncate, + raw_scores.unwrap(), + )) + .await + .expect("`rerank_receiver` was dropped. This is a bug."); + + index += 1; + } + + // Drop the sender to signal to the underlying task that we are done + drop(rerank_sender); + + let batch_size = index; + + let mut ranks = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + // Iterate on result stream + while let Some(r) = response_receiver.recv().await { + let r = r?; + + total_compute_tokens += r.1; + total_tokenization_time += r.2.as_nanos() as u64; + total_queue_time += r.3.as_nanos() as u64; + total_inference_time += r.4.as_nanos() as u64; + let text = if return_text.unwrap() { + Some(r.6) + } else { + None + }; + + ranks.push(Rank { + index: r.0 as u32, + text, + score: r.5, + }) + } + + // Check that the outputs have the correct size + if ranks.len() < batch_size { + let message = "rerank results is missing values".to_string(); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Backend, + }; + metrics::increment_counter!("te_request_failure", "err" => "missing_values"); + Err(err)?; + } + + // Reverse sort + ranks.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); + ranks.reverse(); + + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + let response_metadata = ResponseMetadata { + compute_chars: total_compute_chars, + compute_tokens: total_compute_tokens, + start_time, + tokenization_time: Duration::from_nanos(total_tokenization_time / batch_size), + queue_time: Duration::from_nanos(total_queue_time / batch_size), + inference_time: Duration::from_nanos(total_inference_time / batch_size), + }; + response_metadata.record_span(&span); + response_metadata.record_metrics(); + + let message = RerankResponse { + ranks, + metadata: Some(grpc::Metadata::from(&response_metadata)), + }; + + let headers = HeaderMap::from(response_metadata); + + tracing::info!("Success"); Ok(Response::from_parts( MetadataMap::from_headers(headers), - response, + message, Extensions::default(), )) } @@ -571,7 +931,7 @@ pub async fn run( .build()?; // Main service - let service = TextEmbeddingsService { infer, info }; + let service = TextEmbeddingsService::new(infer, info); // Create gRPC server tracing::info!("Starting gRPC server: {}", &addr); diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 0d947b8a..74b41350 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -20,6 +20,7 @@ use std::time::{Duration, Instant}; use 
text_embeddings_backend::BackendError; use text_embeddings_core::infer::{Infer, InferResponse}; use text_embeddings_core::TextEmbeddingsError; +use tokio::sync::OwnedSemaphorePermit; use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::instrument; use utoipa::OpenApi; @@ -94,8 +95,13 @@ async fn predict( truncate: bool, raw_scores: bool, infer: Infer, - info: Info| async move { - let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + info: Info, + permit: Option| async move { + let permit = match permit { + None => infer.acquire_permit().await, + Some(permit) => permit, + }; + let response = infer .predict(inputs, truncate, raw_scores, permit) .await @@ -138,8 +144,16 @@ async fn predict( metrics::increment_counter!("te_request_count", "method" => "single"); let compute_chars = inputs.count_chars(); - let (prompt_tokens, tokenization, queue, inference, predictions) = - predict_inner(inputs, req.truncate, req.raw_scores, infer.0, info.0).await?; + let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner( + inputs, + req.truncate, + req.raw_scores, + infer.0, + info.0, + Some(permit), + ) + .await?; metrics::increment_counter!("te_request_success", "method" => "single"); @@ -183,6 +197,7 @@ async fn predict( req.raw_scores, local_infer.0, local_info.0, + None, )) } let results = join_all(futures).await.into_iter().collect:: Date: Mon, 27 Nov 2023 15:05:35 +0100 Subject: [PATCH 12/12] preserve order streams --- Cargo.lock | 1 + backends/src/lib.rs | 2 - load_tests/load.js | 4 +- load_tests/load_grpc.js | 4 +- load_tests/load_grpc_stream.js | 7 +- router/Cargo.toml | 1 + router/src/grpc/server.rs | 337 ++++++++------- router/src/http/server.rs | 724 +++++++++++++-------------------- router/src/lib.rs | 121 +++++- 9 files changed, 570 insertions(+), 631 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 10ef7cef..348db95e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3196,6 +3196,7 @@ dependencies = [ "clap", "futures", "hf-hub", + "http 0.2.11", "init-tracing-opentelemetry", "metrics", "metrics-exporter-prometheus", diff --git a/backends/src/lib.rs b/backends/src/lib.rs index c3620938..fd089515 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -137,8 +137,6 @@ fn init_backend( } else if cfg!(feature = "python") { #[cfg(feature = "python")] { - use std::thread; - return Ok(Box::new( thread::spawn(move || { PythonBackend::new( diff --git a/load_tests/load.js b/load_tests/load.js index a537c389..abaa2090 100644 --- a/load_tests/load.js +++ b/load_tests/load.js @@ -18,8 +18,8 @@ export const options = { scenarios: { // throughput: { // executor: 'shared-iterations', - // vus: 1000, - // iterations: 1000, + // vus: 5000, + // iterations: 5000, // maxDuration: '2m', // gracefulStop: '1s', // }, diff --git a/load_tests/load_grpc.js b/load_tests/load_grpc.js index fcd9b5b5..4c14407a 100644 --- a/load_tests/load_grpc.js +++ b/load_tests/load_grpc.js @@ -18,8 +18,8 @@ export const options = { scenarios: { // throughput: { // executor: 'shared-iterations', - // vus: 1000, - // iterations: 1000, + // vus: 10000, + // iterations: 10000, // maxDuration: '2m', // gracefulStop: '1s', // }, diff --git a/load_tests/load_grpc_stream.js b/load_tests/load_grpc_stream.js index 89b9115c..42ab489f 100644 --- a/load_tests/load_grpc_stream.js +++ b/load_tests/load_grpc_stream.js @@ -1,8 +1,9 @@ import grpc from 'k6/experimental/grpc'; -import {Trend} from 'k6/metrics'; +import 
{Counter, Trend} from 'k6/metrics'; const host = __ENV.HOST || '127.0.0.1:3000'; +const streamCounter = new Counter('stream_counter'); const totalTime = new Trend('total_time', true); const tokenizationTIme = new Trend('tokenization_time', true); const queueTime = new Trend('queue_time', true); @@ -52,8 +53,8 @@ export default function () { truncate: true, }; - // send 128 requests - for (let i = 0; i < 64000; i++) { + // send 10000 requests + for (let i = 0; i < 10000; i++) { stream.write(payload); } diff --git a/router/Cargo.toml b/router/Cargo.toml index b9e4ecd1..32d130b9 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -22,6 +22,7 @@ clap = { version = "4.1.4", features = ["derive", "env"] } futures = "^0.3" init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } hf-hub = { version = "0.3.0", features = ["tokio"] } +http = "0.2.9" num_cpus = "1.16.0" metrics = "0.21.0" metrics-exporter-prometheus = { version = "0.12.1", features = [] } diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs index dbc40155..b4983ab5 100644 --- a/router/src/grpc/server.rs +++ b/router/src/grpc/server.rs @@ -3,13 +3,14 @@ use crate::grpc::{ EmbedRequest, EmbedResponse, InfoRequest, InfoResponse, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, }; +use crate::ResponseMetadata; use crate::{grpc, shutdown, ErrorResponse, ErrorType, Info, ModelType}; use futures::future::join_all; use metrics_exporter_prometheus::PrometheusBuilder; use std::net::SocketAddr; use std::time::{Duration, Instant}; -use text_embeddings_core::infer::{Infer, InferResponse}; -use tokio::sync::{mpsc, OwnedSemaphorePermit}; +use text_embeddings_core::infer::Infer; +use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit}; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tonic::codegen::http::HeaderMap; @@ -20,114 +21,6 @@ use tonic::{Code, Extensions, Request, Response, Status, Streaming}; use tonic_health::ServingStatus; use tracing::{instrument, Span}; -struct ResponseMetadata { - compute_chars: usize, - compute_tokens: usize, - start_time: Instant, - tokenization_time: Duration, - queue_time: Duration, - inference_time: Duration, -} - -impl ResponseMetadata { - fn new(compute_chars: usize, start_time: Instant, response: &InferResponse) -> Self { - Self { - compute_chars, - compute_tokens: response.prompt_tokens, - start_time, - tokenization_time: response.tokenization, - queue_time: response.queue, - inference_time: response.inference, - } - } - - fn record_span(&self, span: &Span) { - // Tracing metadata - span.record("compute_chars", self.compute_chars); - span.record("compute_tokens", self.compute_tokens); - span.record("total_time", format!("{:?}", self.start_time.elapsed())); - span.record("tokenization_time", format!("{:?}", self.tokenization_time)); - span.record("queue_time", format!("{:?}", self.queue_time)); - span.record("inference_time", format!("{:?}", self.inference_time)); - } - - fn record_metrics(&self) { - // Metrics - metrics::histogram!( - "te_request_duration", - self.start_time.elapsed().as_secs_f64() - ); - metrics::histogram!( - "te_request_tokenization_duration", - self.tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", self.queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - self.inference_time.as_secs_f64() - ); - } -} - -impl From for HeaderMap { - fn from(value: ResponseMetadata) -> Self { - // Headers - let mut 
headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - value - .start_time - .elapsed() - .as_millis() - .to_string() - .parse() - .unwrap(), - ); - headers.insert( - "x-compute-characters", - value.compute_chars.to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-tokens", - value.compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - value - .start_time - .elapsed() - .as_millis() - .to_string() - .parse() - .unwrap(), - ); - headers.insert( - "x-tokenization-time", - value - .tokenization_time - .as_millis() - .to_string() - .parse() - .unwrap(), - ); - headers.insert( - "x-queue-time", - value.queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - value - .inference_time - .as_millis() - .to_string() - .parse() - .unwrap(), - ); - headers - } -} - impl From<&ResponseMetadata> for grpc::Metadata { fn from(value: &ResponseMetadata) -> Self { Self { @@ -187,7 +80,14 @@ impl TextEmbeddingsService { .await .map_err(ErrorResponse::from)?; - let response_metadata = ResponseMetadata::new(compute_chars, start_time, &response); + let response_metadata = ResponseMetadata::new( + compute_chars, + response.prompt_tokens, + start_time, + response.tokenization, + response.queue, + response.inference, + ); response_metadata.record_span(&span); response_metadata.record_metrics(); @@ -234,7 +134,14 @@ impl TextEmbeddingsService { _ => panic!(), }; - let response_metadata = ResponseMetadata::new(compute_chars, start_time, &response); + let response_metadata = ResponseMetadata::new( + compute_chars, + response.prompt_tokens, + start_time, + response.tokenization, + response.queue, + response.inference, + ); let mut predictions: Vec = { // Map score to label @@ -332,46 +239,77 @@ impl grpc::embed_server::Embed for TextEmbeddingsService { // Create bounded channel to have an upper bound of spawned tasks // We will have at most `max_parallel_stream_requests` messages from this stream in the queue - let (embed_sender, mut embed_receiver) = mpsc::channel(self.max_parallel_stream_requests); - - // Final channel for the outputs - let (response_sender, response_receiver) = mpsc::unbounded_channel(); + let (embed_sender, mut embed_receiver) = mpsc::channel::<( + EmbedRequest, + oneshot::Sender>, + )>(self.max_parallel_stream_requests); // Required for the async move below let local = self.clone(); // Background task that uses the bounded channel tokio::spawn(async move { - while let Some(request) = embed_receiver.recv().await { + while let Some((request, mut sender)) = embed_receiver.recv().await { // Wait on permit before spawning the task to avoid creating more tasks than needed let permit = local.infer.acquire_permit().await; // Required for the async move below let task_local = local.clone(); - let task_response_sender = response_sender.clone(); // Create async task for this specific input tokio::spawn(async move { // Select on closed to cancel work if the stream was closed tokio::select! { response = task_local.embed_inner(request, permit) => { - let _ = task_response_sender.send(response.map(|(r, _m)| r)); + let _ = sender.send(response.map(|(r, _m)| r)); } - _ = task_response_sender.closed() => {} + _ = sender.closed() => {} } }); } }); - // Iterate on input - while let Some(request) = request_stream.next().await { - embed_sender - .send(request?) - .await - .expect("`embed_receiver` was dropped. 
This is a bug."); - } - // Drop the sender - drop(embed_sender); + // Intermediate channels + // Required to keep the order of the requests + let (intermediate_sender, mut intermediate_receiver) = mpsc::unbounded_channel(); + + tokio::spawn(async move { + // Iterate on input + while let Some(request) = request_stream.next().await { + // Create return channel + let (result_sender, result_receiver) = oneshot::channel(); + // Push to intermediate channel and preserve ordering + intermediate_sender + .send(result_receiver) + .expect("`intermediate_receiver` was dropped. This is a bug."); + + match request { + Ok(request) => embed_sender + .send((request, result_sender)) + .await + .expect("`embed_receiver` was dropped. This is a bug."), + Err(status) => { + // Request is malformed + let _ = result_sender.send(Err(status)); + } + }; + } + }); + + // Final channel for the outputs + let (response_sender, response_receiver) = mpsc::unbounded_channel(); + + tokio::spawn(async move { + while let Some(result_receiver) = intermediate_receiver.recv().await { + // Select on closed to cancel work if the stream was closed + tokio::select! { + response = result_receiver => { + let _ = response_sender.send(response.expect("`result_sender` was dropped. This is a bug.")); + } + _ = response_sender.closed() => {} + } + } + }); Ok(Response::new(UnboundedReceiverStream::new( response_receiver, @@ -417,47 +355,77 @@ impl grpc::predict_server::Predict for TextEmbeddingsService { // Create bounded channel to have an upper bound of spawned tasks // We will have at most `max_parallel_stream_requests` messages from this stream in the queue - let (predict_sender, mut predict_receiver) = - mpsc::channel(self.max_parallel_stream_requests); - - // Final channel for the outputs - let (response_sender, response_receiver) = mpsc::unbounded_channel(); + let (predict_sender, mut predict_receiver) = mpsc::channel::<( + PredictRequest, + oneshot::Sender>, + )>(self.max_parallel_stream_requests); // Required for the async move below let local = self.clone(); // Background task that uses the bounded channel tokio::spawn(async move { - while let Some(request) = predict_receiver.recv().await { + while let Some((request, mut sender)) = predict_receiver.recv().await { // Wait on permit before spawning the task to avoid creating more tasks than needed let permit = local.infer.acquire_permit().await; // Required for the async move below let task_local = local.clone(); - let task_response_sender = response_sender.clone(); // Create async task for this specific input tokio::spawn(async move { // Select on closed to cancel work if the stream was closed tokio::select! { response = task_local.predict_inner(request, permit) => { - let _ = task_response_sender.send(response.map(|(r, _m)| r)); + let _ = sender.send(response.map(|(r, _m)| r)); } - _ = task_response_sender.closed() => {} + _ = sender.closed() => {} } }); } }); - // Iterate on input - while let Some(request) = request_stream.next().await { - predict_sender - .send(request?) - .await - .expect("`predict_receiver` was dropped. 
This is a bug."); - } - // Drop the sender - drop(predict_sender); + // Intermediate channels + // Required to keep the order of the requests + let (intermediate_sender, mut intermediate_receiver) = mpsc::unbounded_channel(); + + tokio::spawn(async move { + // Iterate on input + while let Some(request) = request_stream.next().await { + // Create return channel + let (result_sender, result_receiver) = oneshot::channel(); + // Push to intermediate channel and preserve ordering + intermediate_sender + .send(result_receiver) + .expect("`intermediate_receiver` was dropped. This is a bug."); + + match request { + Ok(request) => predict_sender + .send((request, result_sender)) + .await + .expect("`predict_receiver` was dropped. This is a bug."), + Err(status) => { + // Request is malformed + let _ = result_sender.send(Err(status)); + } + }; + } + }); + + // Final channel for the outputs + let (response_sender, response_receiver) = mpsc::unbounded_channel(); + + tokio::spawn(async move { + while let Some(result_receiver) = intermediate_receiver.recv().await { + // Select on closed to cancel work if the stream was closed + tokio::select! { + response = result_receiver => { + let _ = response_sender.send(response.expect("`result_sender` was dropped. This is a bug.")); + } + _ = response_sender.closed() => {} + } + } + }); Ok(Response::new(UnboundedReceiverStream::new( response_receiver, @@ -596,14 +564,14 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { metrics::increment_counter!("te_request_success", "method" => "batch"); - let response_metadata = ResponseMetadata { - compute_chars: total_compute_chars, - compute_tokens: total_compute_tokens, + let response_metadata = ResponseMetadata::new( + total_compute_chars, + total_compute_tokens, start_time, - tokenization_time: Duration::from_nanos(total_tokenization_time / batch_size), - queue_time: Duration::from_nanos(total_queue_time / batch_size), - inference_time: Duration::from_nanos(total_inference_time / batch_size), - }; + Duration::from_nanos(total_tokenization_time / batch_size), + Duration::from_nanos(total_queue_time / batch_size), + Duration::from_nanos(total_inference_time / batch_size), + ); response_metadata.record_span(&span); response_metadata.record_metrics(); @@ -690,24 +658,25 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { // Create bounded channel to have an upper bound of spawned tasks // We will have at most `max_parallel_stream_requests` messages from this stream in the queue - let (rerank_sender, mut rerank_receiver) = mpsc::channel(self.max_parallel_stream_requests); - - // Final channel for the outputs - let (response_sender, mut response_receiver) = mpsc::unbounded_channel(); + let (rerank_sender, mut rerank_receiver) = mpsc::channel::<( + (usize, String, String, bool, bool), + oneshot::Sender< + Result<(usize, usize, Duration, Duration, Duration, f32, String), ErrorResponse>, + >, + )>(self.max_parallel_stream_requests); // Required for the async move below let local_infer = self.infer.clone(); // Background task that uses the bounded channel tokio::spawn(async move { - while let Some((index, query, text, truncate, raw_scores)) = + while let Some(((index, query, text, truncate, raw_scores), mut sender)) = rerank_receiver.recv().await { // Wait on permit before spawning the task to avoid creating more tasks than needed let permit = local_infer.acquire_permit().await; // Required for the async move below - let task_response_sender = response_sender.clone(); let task_infer = local_infer.clone(); 
// Create async task for this specific input @@ -715,9 +684,9 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { // Select on closed to cancel work if the stream was closed tokio::select! { result = rerank_inner(index, query, text, truncate, raw_scores, task_infer, permit) => { - let _ = task_response_sender.send(result); + let _ = sender.send(result); } - _ = task_response_sender.closed() => {} + _ = sender.closed() => {} } }); } @@ -730,9 +699,20 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { let mut raw_scores = None; let mut return_text = None; + // Intermediate channels + // Required to keep the order of the requests + let (intermediate_sender, mut intermediate_receiver) = mpsc::unbounded_channel(); + while let Some(request) = request_stream.next().await { let request = request?; + // Create return channel + let (result_sender, result_receiver) = oneshot::channel(); + // Push to intermediate channel and preserve ordering + intermediate_sender + .send(result_receiver) + .expect("`intermediate_receiver` was dropped. This is a bug."); + // Set `raw_scores` and `return_text` using the values in the first request if raw_scores.is_none() && return_text.is_none() { raw_scores = Some(request.raw_scores); @@ -744,11 +724,14 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { rerank_sender .send(( - index, - request.query, - request.text, - request.truncate, - raw_scores.unwrap(), + ( + index, + request.query, + request.text, + request.truncate, + raw_scores.unwrap(), + ), + result_sender, )) .await .expect("`rerank_receiver` was dropped. This is a bug."); @@ -768,8 +751,10 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { let mut total_compute_tokens = 0; // Iterate on result stream - while let Some(r) = response_receiver.recv().await { - let r = r?; + while let Some(result_receiver) = intermediate_receiver.recv().await { + let r = result_receiver + .await + .expect("`result_sender` was dropped. 
This is a bug.")?; total_compute_tokens += r.1; total_tokenization_time += r.2.as_nanos() as u64; @@ -808,14 +793,14 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService { metrics::increment_counter!("te_request_success", "method" => "batch"); - let response_metadata = ResponseMetadata { - compute_chars: total_compute_chars, - compute_tokens: total_compute_tokens, + let response_metadata = ResponseMetadata::new( + total_compute_chars, + total_compute_tokens, start_time, - tokenization_time: Duration::from_nanos(total_tokenization_time / batch_size), - queue_time: Duration::from_nanos(total_queue_time / batch_size), - inference_time: Duration::from_nanos(total_inference_time / batch_size), - }; + Duration::from_nanos(total_tokenization_time / batch_size), + Duration::from_nanos(total_queue_time / batch_size), + Duration::from_nanos(total_inference_time / batch_size), + ); response_metadata.record_span(&span); response_metadata.record_metrics(); diff --git a/router/src/http/server.rs b/router/src/http/server.rs index 74b41350..bd67a5bd 100644 --- a/router/src/http/server.rs +++ b/router/src/http/server.rs @@ -4,7 +4,10 @@ use crate::http::types::{ OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, }; -use crate::{shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType}; +use crate::{ + shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType, + ResponseMetadata, +}; use anyhow::Context; use axum::extract::Extension; use axum::http::HeaderValue; @@ -138,152 +141,110 @@ async fn predict( )) }; - let (compute_chars, compute_tokens, tokenization_time, queue_time, inference_time, response) = - match req.inputs { - PredictInput::Single(inputs) => { - metrics::increment_counter!("te_request_count", "method" => "single"); + let (response, metadata) = match req.inputs { + PredictInput::Single(inputs) => { + metrics::increment_counter!("te_request_count", "method" => "single"); - let compute_chars = inputs.count_chars(); - let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; - let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner( - inputs, - req.truncate, - req.raw_scores, - infer.0, - info.0, - Some(permit), - ) - .await?; + let compute_chars = inputs.count_chars(); + let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner( + inputs, + req.truncate, + req.raw_scores, + infer.0, + info.0, + Some(permit), + ) + .await?; - metrics::increment_counter!("te_request_success", "method" => "single"); + metrics::increment_counter!("te_request_success", "method" => "single"); - ( + ( + PredictResponse::Single(predictions), + ResponseMetadata::new( compute_chars, prompt_tokens, + start_time, tokenization, queue, inference, - PredictResponse::Single(predictions), - ) + ), + ) + } + PredictInput::Batch(inputs) => { + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let batch_size = inputs.len(); + if batch_size > info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; } - 
PredictInput::Batch(inputs) => { - metrics::increment_counter!("te_request_count", "method" => "batch"); - - let batch_size = inputs.len(); - if batch_size > info.max_client_batch_size { - let message = format!( - "batch size {batch_size} > maximum allowed batch size {}", - info.max_client_batch_size - ); - tracing::error!("{message}"); - let err = ErrorResponse { - error: message, - error_type: ErrorType::Validation, - }; - metrics::increment_counter!("te_request_failure", "err" => "batch_size"); - Err(err)?; - } - - let mut futures = Vec::with_capacity(batch_size); - let mut compute_chars = 0; - - for input in inputs { - compute_chars += input.count_chars(); - let local_infer = infer.clone(); - let local_info = info.clone(); - futures.push(predict_inner( - input, - req.truncate, - req.raw_scores, - local_infer.0, - local_info.0, - None, - )) - } - let results = join_all(futures).await.into_iter().collect::)>, - ErrorResponse, - >>()?; - - let mut predictions = Vec::with_capacity(batch_size); - let mut total_tokenization_time = 0; - let mut total_queue_time = 0; - let mut total_inference_time = 0; - let mut total_compute_tokens = 0; - - for r in results { - total_compute_tokens += r.0; - total_tokenization_time += r.1.as_nanos() as u64; - total_queue_time += r.2.as_nanos() as u64; - total_inference_time += r.3.as_nanos() as u64; - predictions.push(r.4); - } - let batch_size = batch_size as u64; - - metrics::increment_counter!("te_request_success", "method" => "batch"); - - ( + + let mut futures = Vec::with_capacity(batch_size); + let mut compute_chars = 0; + + for input in inputs { + compute_chars += input.count_chars(); + let local_infer = infer.clone(); + let local_info = info.clone(); + futures.push(predict_inner( + input, + req.truncate, + req.raw_scores, + local_infer.0, + local_info.0, + None, + )) + } + let results = join_all(futures).await.into_iter().collect::)>, + ErrorResponse, + >>()?; + + let mut predictions = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + for r in results { + total_compute_tokens += r.0; + total_tokenization_time += r.1.as_nanos() as u64; + total_queue_time += r.2.as_nanos() as u64; + total_inference_time += r.3.as_nanos() as u64; + predictions.push(r.4); + } + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + ( + PredictResponse::Batch(predictions), + ResponseMetadata::new( compute_chars, total_compute_tokens, + start_time, Duration::from_nanos(total_tokenization_time / batch_size), Duration::from_nanos(total_queue_time / batch_size), Duration::from_nanos(total_inference_time / batch_size), - PredictResponse::Batch(predictions), - ) - } - }; + ), + ) + } + }; + + metadata.record_span(&span); + metadata.record_metrics(); - let total_time = start_time.elapsed(); - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - - // Headers - let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - compute_chars.to_string().parse().unwrap(), - ); - headers.insert( - 
"x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), - ); - - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); + let headers = HeaderMap::from(metadata); tracing::info!("Success"); @@ -367,7 +328,7 @@ async fn rerank( )) }; - let (compute_chars, compute_tokens, tokenization_time, queue_time, inference_time, response) = { + let (response, metadata) = { metrics::increment_counter!("te_request_count", "method" => "batch"); let batch_size = req.texts.len(); @@ -438,66 +399,22 @@ async fn rerank( metrics::increment_counter!("te_request_success", "method" => "batch"); ( - compute_chars, - total_compute_tokens, - Duration::from_nanos(total_tokenization_time / batch_size), - Duration::from_nanos(total_queue_time / batch_size), - Duration::from_nanos(total_inference_time / batch_size), RerankResponse(ranks), + ResponseMetadata::new( + compute_chars, + total_compute_tokens, + start_time, + Duration::from_nanos(total_tokenization_time / batch_size), + Duration::from_nanos(total_queue_time / batch_size), + Duration::from_nanos(total_inference_time / batch_size), + ), ) }; - let total_time = start_time.elapsed(); - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - - // Headers - let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - compute_chars.to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), - ); - - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); + metadata.record_span(&span); + metadata.record_metrics(); + + let headers = HeaderMap::from(metadata); tracing::info!("Success"); @@ -534,147 +451,105 @@ async fn embed( let span = tracing::Span::current(); let start_time = Instant::now(); - let (compute_chars, 
compute_tokens, tokenization_time, queue_time, inference_time, response) = - match req.inputs { - Input::Single(input) => { - metrics::increment_counter!("te_request_count", "method" => "single"); + let (response, metadata) = match req.inputs { + Input::Single(input) => { + metrics::increment_counter!("te_request_count", "method" => "single"); - let compute_chars = input.chars().count(); + let compute_chars = input.chars().count(); - let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; - let response = infer - .embed(input, req.truncate, req.normalize, permit) - .await - .map_err(ErrorResponse::from)?; + let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + let response = infer + .embed(input, req.truncate, req.normalize, permit) + .await + .map_err(ErrorResponse::from)?; - metrics::increment_counter!("te_request_success", "method" => "single"); + metrics::increment_counter!("te_request_success", "method" => "single"); - ( + ( + EmbedResponse(vec![response.results]), + ResponseMetadata::new( compute_chars, response.prompt_tokens, + start_time, response.tokenization, response.queue, response.inference, - EmbedResponse(vec![response.results]), - ) + ), + ) + } + Input::Batch(inputs) => { + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let batch_size = inputs.len(); + if batch_size > info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; } - Input::Batch(inputs) => { - metrics::increment_counter!("te_request_count", "method" => "batch"); - - let batch_size = inputs.len(); - if batch_size > info.max_client_batch_size { - let message = format!( - "batch size {batch_size} > maximum allowed batch size {}", - info.max_client_batch_size - ); - tracing::error!("{message}"); - let err = ErrorResponse { - error: message, - error_type: ErrorType::Validation, - }; - metrics::increment_counter!("te_request_failure", "err" => "batch_size"); - Err(err)?; - } - - let mut futures = Vec::with_capacity(batch_size); - let mut compute_chars = 0; - - for input in inputs { - compute_chars += input.chars().count(); - - let local_infer = infer.clone(); - futures.push(async move { - let permit = local_infer.acquire_permit().await; - local_infer - .embed(input, req.truncate, req.normalize, permit) - .await - }) - } - let results = join_all(futures) - .await - .into_iter() - .collect::, TextEmbeddingsError>>() - .map_err(ErrorResponse::from)?; - - let mut embeddings = Vec::with_capacity(batch_size); - let mut total_tokenization_time = 0; - let mut total_queue_time = 0; - let mut total_inference_time = 0; - let mut total_compute_tokens = 0; - - for r in results { - total_tokenization_time += r.tokenization.as_nanos() as u64; - total_queue_time += r.queue.as_nanos() as u64; - total_inference_time += r.inference.as_nanos() as u64; - total_compute_tokens += r.prompt_tokens; - embeddings.push(r.results); - } - let batch_size = batch_size as u64; - - metrics::increment_counter!("te_request_success", "method" => "batch"); - - ( + + let mut futures = Vec::with_capacity(batch_size); + let mut compute_chars = 0; + + for input in inputs { + compute_chars += input.chars().count(); + + let local_infer = infer.clone(); + futures.push(async move { + let permit = 
local_infer.acquire_permit().await; + local_infer + .embed(input, req.truncate, req.normalize, permit) + .await + }) + } + let results = join_all(futures) + .await + .into_iter() + .collect::, TextEmbeddingsError>>() + .map_err(ErrorResponse::from)?; + + let mut embeddings = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + for r in results { + total_tokenization_time += r.tokenization.as_nanos() as u64; + total_queue_time += r.queue.as_nanos() as u64; + total_inference_time += r.inference.as_nanos() as u64; + total_compute_tokens += r.prompt_tokens; + embeddings.push(r.results); + } + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + ( + EmbedResponse(embeddings), + ResponseMetadata::new( compute_chars, total_compute_tokens, + start_time, Duration::from_nanos(total_tokenization_time / batch_size), Duration::from_nanos(total_queue_time / batch_size), Duration::from_nanos(total_inference_time / batch_size), - EmbedResponse(embeddings), - ) - } - }; + ), + ) + } + }; + + metadata.record_span(&span); + metadata.record_metrics(); - let total_time = start_time.elapsed(); - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - - // Headers - let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - compute_chars.to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), - ); - - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); + let headers = HeaderMap::from(metadata); tracing::info!("Success"); @@ -712,153 +587,112 @@ async fn openai_embed( let span = tracing::Span::current(); let start_time = Instant::now(); - let (compute_chars, compute_tokens, tokenization_time, queue_time, inference_time, embeddings) = - match req.input { - Input::Single(input) => { - metrics::increment_counter!("te_request_count", "method" => "single"); + let (embeddings, metadata) = match req.input { + Input::Single(input) => { + metrics::increment_counter!("te_request_count", "method" => "single"); - let compute_chars = input.chars().count(); + let compute_chars = input.chars().count(); - let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; - let response = infer - .embed(input, false, true, permit) - .await - .map_err(ErrorResponse::from)?; + let 
permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?; + let response = infer + .embed(input, false, true, permit) + .await + .map_err(ErrorResponse::from)?; - metrics::increment_counter!("te_request_success", "method" => "single"); + metrics::increment_counter!("te_request_success", "method" => "single"); - ( + ( + vec![OpenAICompatEmbedding { + object: "embedding", + embedding: response.results, + index: 0, + }], + ResponseMetadata::new( compute_chars, response.prompt_tokens, + start_time, response.tokenization, response.queue, response.inference, - vec![OpenAICompatEmbedding { - object: "embedding", - embedding: response.results, - index: 0, - }], - ) + ), + ) + } + Input::Batch(inputs) => { + metrics::increment_counter!("te_request_count", "method" => "batch"); + + let batch_size = inputs.len(); + if batch_size > info.max_client_batch_size { + let message = format!( + "batch size {batch_size} > maximum allowed batch size {}", + info.max_client_batch_size + ); + tracing::error!("{message}"); + let err = ErrorResponse { + error: message, + error_type: ErrorType::Validation, + }; + metrics::increment_counter!("te_request_failure", "err" => "batch_size"); + Err(err)?; + } + + let mut futures = Vec::with_capacity(batch_size); + let mut compute_chars = 0; + + for input in inputs { + compute_chars += input.chars().count(); + + let local_infer = infer.clone(); + futures.push(async move { + let permit = local_infer.acquire_permit().await; + local_infer.embed(input, false, true, permit).await + }) } - Input::Batch(inputs) => { - metrics::increment_counter!("te_request_count", "method" => "batch"); - - let batch_size = inputs.len(); - if batch_size > info.max_client_batch_size { - let message = format!( - "batch size {batch_size} > maximum allowed batch size {}", - info.max_client_batch_size - ); - tracing::error!("{message}"); - let err = ErrorResponse { - error: message, - error_type: ErrorType::Validation, - }; - metrics::increment_counter!("te_request_failure", "err" => "batch_size"); - Err(err)?; - } - - let mut futures = Vec::with_capacity(batch_size); - let mut compute_chars = 0; - - for input in inputs { - compute_chars += input.chars().count(); - - let local_infer = infer.clone(); - futures.push(async move { - let permit = local_infer.acquire_permit().await; - local_infer.embed(input, false, true, permit).await - }) - } - let results = join_all(futures) - .await - .into_iter() - .collect::, TextEmbeddingsError>>() - .map_err(ErrorResponse::from)?; - - let mut embeddings = Vec::with_capacity(batch_size); - let mut total_tokenization_time = 0; - let mut total_queue_time = 0; - let mut total_inference_time = 0; - let mut total_compute_tokens = 0; - - for (i, r) in results.into_iter().enumerate() { - total_tokenization_time += r.tokenization.as_nanos() as u64; - total_queue_time += r.queue.as_nanos() as u64; - total_inference_time += r.inference.as_nanos() as u64; - total_compute_tokens += r.prompt_tokens; - embeddings.push(OpenAICompatEmbedding { - object: "embedding", - embedding: r.results, - index: i, - }); - } - let batch_size = batch_size as u64; - - metrics::increment_counter!("te_request_success", "method" => "batch"); - - ( + let results = join_all(futures) + .await + .into_iter() + .collect::, TextEmbeddingsError>>() + .map_err(ErrorResponse::from)?; + + let mut embeddings = Vec::with_capacity(batch_size); + let mut total_tokenization_time = 0; + let mut total_queue_time = 0; + let mut total_inference_time = 0; + let mut total_compute_tokens = 0; + + for (i, r) in 
results.into_iter().enumerate() { + total_tokenization_time += r.tokenization.as_nanos() as u64; + total_queue_time += r.queue.as_nanos() as u64; + total_inference_time += r.inference.as_nanos() as u64; + total_compute_tokens += r.prompt_tokens; + embeddings.push(OpenAICompatEmbedding { + object: "embedding", + embedding: r.results, + index: i, + }); + } + let batch_size = batch_size as u64; + + metrics::increment_counter!("te_request_success", "method" => "batch"); + + ( + embeddings, + ResponseMetadata::new( compute_chars, total_compute_tokens, + start_time, Duration::from_nanos(total_tokenization_time / batch_size), Duration::from_nanos(total_queue_time / batch_size), Duration::from_nanos(total_inference_time / batch_size), - embeddings, - ) - } - }; + ), + ) + } + }; + + metadata.record_span(&span); + metadata.record_metrics(); - let total_time = start_time.elapsed(); - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("tokenization_time", format!("{tokenization_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - - // Headers - let mut headers = HeaderMap::new(); - headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); - headers.insert( - "x-compute-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-characters", - compute_chars.to_string().parse().unwrap(), - ); - headers.insert( - "x-compute-tokens", - compute_tokens.to_string().parse().unwrap(), - ); - headers.insert( - "x-total-time", - total_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-tokenization-time", - tokenization_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-queue-time", - queue_time.as_millis().to_string().parse().unwrap(), - ); - headers.insert( - "x-inference-time", - inference_time.as_millis().to_string().parse().unwrap(), - ); - - // Metrics - metrics::histogram!("te_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "te_request_tokenization_duration", - tokenization_time.as_secs_f64() - ); - metrics::histogram!("te_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "te_request_inference_duration", - inference_time.as_secs_f64() - ); + let compute_tokens = metadata.compute_tokens; + let headers = HeaderMap::from(metadata); tracing::info!("Success"); diff --git a/router/src/lib.rs b/router/src/lib.rs index 06772883..d7164707 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,10 +1,13 @@ -use anyhow::Result; +use ::http::HeaderMap; /// Text Embedding Inference Webserver +use anyhow::Result; use serde::Serialize; use std::collections::HashMap; use std::net::SocketAddr; +use std::time::{Duration, Instant}; use text_embeddings_core::infer::Infer; use text_embeddings_core::TextEmbeddingsError; +use tracing::Span; mod prometheus; @@ -15,6 +18,7 @@ mod http; mod grpc; mod shutdown; +/// Crate entrypoint pub async fn run(infer: Infer, info: Info, addr: SocketAddr) -> Result<()> { let prom_builder = prometheus::prometheus_builer(info.max_input_length)?; @@ -130,3 +134,118 @@ impl From for ErrorResponse { } } } + +struct ResponseMetadata { + compute_chars: usize, + compute_tokens: usize, + start_time: Instant, + tokenization_time: Duration, + queue_time: Duration, + inference_time: Duration, +} + +impl ResponseMetadata { + fn new( + compute_chars: usize, + compute_tokens: usize, + start_time: Instant, + tokenization_time: Duration, + 
queue_time: Duration, + inference_time: Duration, + ) -> Self { + Self { + compute_chars, + compute_tokens, + start_time, + tokenization_time, + queue_time, + inference_time, + } + } + + fn record_span(&self, span: &Span) { + // Tracing metadata + span.record("compute_chars", self.compute_chars); + span.record("compute_tokens", self.compute_tokens); + span.record("total_time", format!("{:?}", self.start_time.elapsed())); + span.record("tokenization_time", format!("{:?}", self.tokenization_time)); + span.record("queue_time", format!("{:?}", self.queue_time)); + span.record("inference_time", format!("{:?}", self.inference_time)); + } + + fn record_metrics(&self) { + // Metrics + metrics::histogram!( + "te_request_duration", + self.start_time.elapsed().as_secs_f64() + ); + metrics::histogram!( + "te_request_tokenization_duration", + self.tokenization_time.as_secs_f64() + ); + metrics::histogram!("te_request_queue_duration", self.queue_time.as_secs_f64()); + metrics::histogram!( + "te_request_inference_duration", + self.inference_time.as_secs_f64() + ); + } +} + +impl From for HeaderMap { + fn from(value: ResponseMetadata) -> Self { + // Headers + let mut headers = HeaderMap::new(); + headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); + headers.insert( + "x-compute-time", + value + .start_time + .elapsed() + .as_millis() + .to_string() + .parse() + .unwrap(), + ); + headers.insert( + "x-compute-characters", + value.compute_chars.to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-tokens", + value.compute_tokens.to_string().parse().unwrap(), + ); + headers.insert( + "x-total-time", + value + .start_time + .elapsed() + .as_millis() + .to_string() + .parse() + .unwrap(), + ); + headers.insert( + "x-tokenization-time", + value + .tokenization_time + .as_millis() + .to_string() + .parse() + .unwrap(), + ); + headers.insert( + "x-queue-time", + value.queue_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-inference-time", + value + .inference_time + .as_millis() + .to_string() + .parse() + .unwrap(), + ); + headers + } +}
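
Note on the ordering change in PATCH 12/12: the gRPC streaming endpoints above keep responses in request order by pairing every incoming request with a oneshot reply channel and queuing the oneshot receivers in arrival order, while the actual work still runs concurrently behind a bounded channel. The snippet below is a minimal, self-contained sketch of that pattern only; the names (process, work_tx, order_tx, reply_tx) are illustrative and are not part of this crate, and the real handlers additionally use tokio::select! against sender.closed() to cancel work when the client stream goes away. It assumes tokio with the "macros", "rt-multi-thread", and "sync" features.

use tokio::sync::{mpsc, oneshot};

async fn process(x: u32) -> u32 {
    // Placeholder for the real inference call.
    x * 2
}

#[tokio::main]
async fn main() {
    // Bounded channel caps the number of in-flight work items
    // (analogous to `max_parallel_stream_requests`).
    let (work_tx, mut work_rx) = mpsc::channel::<(u32, oneshot::Sender<u32>)>(8);
    // Unbounded channel of oneshot receivers preserves arrival order.
    let (order_tx, mut order_rx) = mpsc::unbounded_channel::<oneshot::Receiver<u32>>();

    // Worker: spawn one task per item and answer on its dedicated oneshot.
    tokio::spawn(async move {
        while let Some((item, reply_tx)) = work_rx.recv().await {
            tokio::spawn(async move {
                let _ = reply_tx.send(process(item).await);
            });
        }
    });

    // Producer: enqueue the oneshot receiver *before* dispatching the work,
    // so the output side sees receivers in request order.
    tokio::spawn(async move {
        for item in 0..5u32 {
            let (reply_tx, reply_rx) = oneshot::channel();
            order_tx.send(reply_rx).expect("order channel closed");
            work_tx
                .send((item, reply_tx))
                .await
                .expect("work channel closed");
        }
    });

    // Consumer: await each receiver in order; results arrive as 0, 2, 4, 6, 8
    // even if the per-item tasks finish out of order.
    while let Some(reply_rx) = order_rx.recv().await {
        println!("{}", reply_rx.await.expect("worker dropped reply sender"));
    }
}

The design choice mirrored here is that ordering is decided at enqueue time (the intermediate channel of receivers), not at completion time, which is what lets the per-request tasks run fully in parallel without reordering the output stream.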