diff --git a/Dockerfile-cuda-all b/Dockerfile-cuda-all
index ac2dcdb2..5b72fc65 100644
--- a/Dockerfile-cuda-all
+++ b/Dockerfile-cuda-all
@@ -33,6 +33,7 @@ FROM base-builder AS builder
 
 ARG GIT_SHA
 ARG DOCKER_LABEL
+ARG VERTEX
 
 # sccache specific variables
 ARG ACTIONS_CACHE_URL
@@ -45,7 +46,12 @@ COPY --from=planner /usr/src/recipe.json recipe.json
 
 FROM builder as builder-75
 
-RUN CUDA_COMPUTE_CAP=75 cargo chef cook --release --features candle-cuda-turing --no-default-features --recipe-path recipe.json && sccache -s
+RUN if [ $VERTEX = "true" ]; \
+    then \
+        CUDA_COMPUTE_CAP=75 cargo chef cook --release --features google --features candle-cuda-turing --no-default-features --recipe-path recipe.json && sccache -s; \
+    else \
+        CUDA_COMPUTE_CAP=75 cargo chef cook --release --features candle-cuda-turing --no-default-features --recipe-path recipe.json && sccache -s; \
+    fi;
 
 COPY backends backends
 COPY core core
@@ -53,11 +59,21 @@ COPY router router
 COPY Cargo.toml ./
 COPY Cargo.lock ./
 
-RUN CUDA_COMPUTE_CAP=75 cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F http --no-default-features && sccache -s
+RUN if [ $VERTEX = "true" ]; \
+    then \
+        CUDA_COMPUTE_CAP=75 cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F http -F google --no-default-features && sccache -s; \
+    else \
+        CUDA_COMPUTE_CAP=75 cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F http --no-default-features && sccache -s; \
+    fi;
 
 FROM builder as builder-80
 
-RUN CUDA_COMPUTE_CAP=80 cargo chef cook --release --features candle-cuda --no-default-features --recipe-path recipe.json && sccache -s
+RUN if [ $VERTEX = "true" ]; \
+    then \
+        CUDA_COMPUTE_CAP=80 cargo chef cook --release --features google --features candle-cuda-turing --no-default-features --recipe-path recipe.json && sccache -s; \
+    else \
+        CUDA_COMPUTE_CAP=80 cargo chef cook --release --features candle-cuda-turing --no-default-features --recipe-path recipe.json && sccache -s; \
+    fi;
 
 COPY backends backends
 COPY core core
@@ -65,11 +81,21 @@ COPY router router
 COPY Cargo.toml ./
 COPY Cargo.lock ./
 
-RUN CUDA_COMPUTE_CAP=80 cargo build --release --bin text-embeddings-router -F candle-cuda -F http --no-default-features && sccache -s
+RUN if [ $VERTEX = "true" ]; \
+    then \
+        CUDA_COMPUTE_CAP=80 cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F http -F google --no-default-features && sccache -s; \
+    else \
+        CUDA_COMPUTE_CAP=80 cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F http --no-default-features && sccache -s; \
+    fi;
 
 FROM builder as builder-90
 
-RUN CUDA_COMPUTE_CAP=90 cargo chef cook --release --features candle-cuda --no-default-features --recipe-path recipe.json && sccache -s
+RUN if [ $VERTEX = "true" ]; \
+    then \
+        CUDA_COMPUTE_CAP=90 cargo chef cook --release --features google --features candle-cuda-turing --no-default-features --recipe-path recipe.json && sccache -s; \
+    else \
+        CUDA_COMPUTE_CAP=90 cargo chef cook --release --features candle-cuda-turing --no-default-features --recipe-path recipe.json && sccache -s; \
+    fi;
 
 COPY backends backends
 COPY core core
@@ -77,7 +103,12 @@ COPY router router
 COPY Cargo.toml ./
 COPY Cargo.lock ./
 
-RUN CUDA_COMPUTE_CAP=90 cargo build --release --bin text-embeddings-router -F candle-cuda -F http --no-default-features && sccache -s
+RUN if [ $VERTEX = "true" ]; \
+    then \
+        CUDA_COMPUTE_CAP=90 cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F http -F google --no-default-features && sccache -s; \
+    else \
+        CUDA_COMPUTE_CAP=90 cargo build --release --bin text-embeddings-router -F candle-cuda-turing -F http --no-default-features && sccache -s; \
+    fi;
 
 FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04 as base
 
diff --git a/docs/openapi.json b/docs/openapi.json
index c9377087..867c7197 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -658,6 +658,87 @@
           }
         }
       }
+    },
+    "/vertex": {
+      "post": {
+        "tags": [
+          "Text Embeddings Inference"
+        ],
+        "summary": "Generate embeddings from a Vertex request",
+        "description": "Generate embeddings from a Vertex request",
+        "operationId": "vertex_compatibility",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/VertexRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Results"
+          },
+          "413": {
+            "description": "Batch size error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Tokenization error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                }
+              }
+            }
+          }
+        }
+      }
     }
   },
   "components": {
@@ -1338,6 +1419,274 @@
             }
           ]
         ]
+      },
+      "VertexInstance": {
+        "oneOf": [
+          {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/EmbedRequest"
+              },
+              {
+                "type": "object",
+                "required": [
+                  "type"
+                ],
+                "properties": {
+                  "type": {
+                    "type": "string",
+                    "enum": [
+                      "embed"
+                    ]
+                  }
+                }
+              }
+            ]
+          },
+          {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/EmbedAllRequest"
+              },
+              {
+                "type": "object",
+                "required": [
+                  "type"
+                ],
+                "properties": {
+                  "type": {
+                    "type": "string",
+                    "enum": [
+                      "embed_all"
+                    ]
+                  }
+                }
+              }
+            ]
+          },
+          {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/EmbedSparseRequest"
+              },
+              {
+                "type": "object",
+                "required": [
+                  "type"
+                ],
+                "properties": {
+                  "type": {
+                    "type": "string",
+                    "enum": [
+                      "embed_sparse"
+                    ]
+                  }
+                }
+              }
+            ]
+          },
+          {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/PredictRequest"
+              },
+              {
+                "type": "object",
+                "required": [
+                  "type"
+                ],
+                "properties": {
+                  "type": {
+                    "type": "string",
+                    "enum": [
+                      "predict"
+                    ]
+                  }
+                }
+              }
+            ]
+          },
+          {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/RerankRequest"
+              },
+              {
+                "type": "object",
+                "required": [
+                  "type"
+                ],
+                "properties": {
+                  "type": {
+                    "type": "string",
+                    "enum": [
+                      "rerank"
+                    ]
+                  }
+                }
+              }
+            ]
+          },
+          {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/TokenizeRequest"
+              },
+              {
+                "type": "object",
+                "required": [
+                  "type"
+                ],
+                "properties": {
+                  "type": {
+                    "type": "string",
+                    "enum": [
+                      "tokenize"
+                    ]
+                  }
+                }
+              }
+            ]
+          }
+        ],
+        "discriminator": {
+          "propertyName": "type"
+        }
+      },
+      "VertexRequest": {
+        "type": "object",
+        "required": [
+          "instances"
+        ],
+        "properties": {
+          "instances": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/VertexInstance"
+            }
+          }
+        }
+      },
+      "VertexResponse": {
+        "type": "array",
+        "items": {
+          "$ref": "#/components/schemas/VertexResponseInstance"
+        }
+      },
+      "VertexResponseInstance": {
+        "oneOf": [
+          {
+            "type": "object",
+            "required": [
+              "type",
+              "result"
+            ],
+            "properties": {
+              "result": {
+                "$ref": "#/components/schemas/EmbedResponse"
+              },
+              "type": {
+                "type": "string",
+                "enum": [
+                  "embed"
+                ]
+              }
+            }
+          },
+          {
+            "type": "object",
+            "required": [
+              "type",
+              "result"
+            ],
+            "properties": {
+              "result": {
+                "$ref": "#/components/schemas/EmbedAllResponse"
+              },
+              "type": {
+                "type": "string",
+                "enum": [
+                  "embed_all"
+                ]
+              }
+            }
+          },
+          {
+            "type": "object",
+            "required": [
+              "type",
+              "result"
+            ],
+            "properties": {
+              "result": {
+                "$ref": "#/components/schemas/EmbedSparseResponse"
+              },
+              "type": {
+                "type": "string",
+                "enum": [
+                  "embed_sparse"
+                ]
+              }
+            }
+          },
+          {
+            "type": "object",
+            "required": [
+              "type",
+              "result"
+            ],
+            "properties": {
+              "result": {
+                "$ref": "#/components/schemas/PredictResponse"
+              },
+              "type": {
+                "type": "string",
+                "enum": [
+                  "predict"
+                ]
+              }
+            }
+          },
+          {
+            "type": "object",
+            "required": [
+              "type",
+              "result"
+            ],
+            "properties": {
+              "result": {
+                "$ref": "#/components/schemas/RerankResponse"
+              },
+              "type": {
+                "type": "string",
+                "enum": [
+                  "rerank"
+                ]
+              }
+            }
+          },
+          {
+            "type": "object",
+            "required": [
+              "type",
+              "result"
+            ],
+            "properties": {
+              "result": {
+                "$ref": "#/components/schemas/TokenizeResponse"
+              },
+              "type": {
+                "type": "string",
+                "enum": [
+                  "tokenize"
+                ]
+              }
+            }
+          }
+        ],
+        "discriminator": {
+          "propertyName": "type"
+        }
       }
     }
   },
diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index 453f75b9..6f4d4eb5 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -4,7 +4,8 @@ use crate::http::types::{
     EmbedSparseResponse, Input, OpenAICompatEmbedding, OpenAICompatErrorResponse,
     OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest,
     PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, SimpleToken,
-    SparseValue, TokenizeRequest, TokenizeResponse, VertexRequest,
+    SparseValue, TokenizeRequest, TokenizeResponse, VertexRequest, VertexResponse,
+    VertexResponseInstance,
 };
 use crate::{
     shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
@@ -19,6 +20,7 @@ use axum::routing::{get, post};
 use axum::{http, Json, Router};
 use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
 use futures::future::join_all;
+use futures::FutureExt;
 use http::header::AUTHORIZATION;
 use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
 use std::net::SocketAddr;
@@ -1158,8 +1160,8 @@ tag = "Text Embeddings Inference",
 path = "/vertex",
 request_body = VertexRequest,
 responses(
-(status = 200, description = "Embeddings", body = EmbedResponse),
-(status = 424, description = "Embedding Error", body = ErrorResponse,
+(status = 200, description = "Results"),
+(status = 424, description = "Error", body = ErrorResponse,
 example = json ! ({"error": "Inference failed", "error_type": "backend"})),
 (status = 429, description = "Model is overloaded", body = ErrorResponse, example = json !
 ({"error": "Model is overloaded", "error_type": "overloaded"})),
@@ -1174,76 +1176,64 @@ async fn vertex_compatibility(
     infer: Extension<Infer>,
     info: Extension<Info>,
     Json(req): Json<VertexRequest>,
-) -> Result<(HeaderMap, Json<EmbedResponse>), (StatusCode, Json<ErrorResponse>)> {
-    let span = tracing::Span::current();
-    let start_time = Instant::now();
-
-    let batch_size = req.instances.len();
-    if batch_size > info.max_client_batch_size {
-        let message = format!(
-            "batch size {batch_size} > maximum allowed batch size {}",
-            info.max_client_batch_size
-        );
-        tracing::error!("{message}");
-        let err = ErrorResponse {
-            error: message,
-            error_type: ErrorType::Validation,
+) -> Result<Json<VertexResponse>, (StatusCode, Json<ErrorResponse>)> {
+    let embed_future = move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedRequest| async move {
+        let result = embed(infer, info, Json(req)).await?;
+        Ok(VertexResponseInstance::Embed(result.1 .0))
+    };
+    let embed_sparse_future =
+        move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedSparseRequest| async move {
+            let result = embed_sparse(infer, info, Json(req)).await?;
+            Ok(VertexResponseInstance::EmbedSparse(result.1 .0))
+        };
+    let predict_future =
+        move |infer: Extension<Infer>, info: Extension<Info>, req: PredictRequest| async move {
+            let result = predict(infer, info, Json(req)).await?;
+            Ok(VertexResponseInstance::Predict(result.1 .0))
+        };
+    let rerank_future =
+        move |infer: Extension<Infer>, info: Extension<Info>, req: RerankRequest| async move {
+            let result = rerank(infer, info, Json(req)).await?;
+            Ok(VertexResponseInstance::Rerank(result.1 .0))
         };
-        metrics::increment_counter!("te_request_failure", "err" => "batch_size");
-        Err(err)?;
-    }
 
-    let mut futures = Vec::with_capacity(batch_size);
-    let mut compute_chars = 0;
+    let mut futures = Vec::with_capacity(req.instances.len());
+    for instance in req.instances {
+        let local_infer = infer.clone();
+        let local_info = info.clone();
 
-    for instance in req.instances.iter() {
-        let input = instance.inputs.clone();
-        compute_chars += input.chars().count();
+        // Rerank is the only payload that can be matched safely
+        if let Ok(instance) = serde_json::from_value::<RerankRequest>(instance.clone()) {
+            futures.push(rerank_future(local_infer, local_info, instance).boxed());
+            continue;
+        }
 
-        let local_infer = infer.clone();
-        futures.push(async move {
-            let permit = local_infer.acquire_permit().await;
-            local_infer
-                .embed_pooled(input, instance.truncate, instance.normalize, permit)
-                .await
-        })
+        match info.model_type {
+            ModelType::Classifier(_) | ModelType::Reranker(_) => {
+                let instance = serde_json::from_value::<PredictRequest>(instance)
+                    .map_err(ErrorResponse::from)?;
+                futures.push(predict_future(local_infer, local_info, instance).boxed());
+            }
+            ModelType::Embedding(_) => {
+                if infer.is_splade() {
+                    let instance = serde_json::from_value::<EmbedSparseRequest>(instance)
+                        .map_err(ErrorResponse::from)?;
+                    futures.push(embed_sparse_future(local_infer, local_info, instance).boxed());
+                } else {
+                    let instance = serde_json::from_value::<EmbedRequest>(instance)
+                        .map_err(ErrorResponse::from)?;
+                    futures.push(embed_future(local_infer, local_info, instance).boxed());
+                }
+            }
+        }
     }
-    let results = join_all(futures)
+
+    let predictions = join_all(futures)
         .await
         .into_iter()
-        .collect::<Result<Vec<_>, TextEmbeddingsError>>()
-        .map_err(ErrorResponse::from)?;
-
-    let mut embeddings = Vec::with_capacity(batch_size);
-    let mut total_tokenization_time = 0;
-    let mut total_queue_time = 0;
-    let mut total_inference_time = 0;
-    let mut total_compute_tokens = 0;
-
-    for r in results {
-        total_tokenization_time += r.metadata.tokenization.as_nanos() as u64;
-        total_queue_time += r.metadata.queue.as_nanos() as u64;
-        total_inference_time += r.metadata.inference.as_nanos() as u64;
-        total_compute_tokens += r.metadata.prompt_tokens;
-        embeddings.push(r.results);
-    }
-    let batch_size = batch_size as u64;
-
-    let response = EmbedResponse(embeddings);
-    let metadata = ResponseMetadata::new(
-        compute_chars,
-        total_compute_tokens,
-        start_time,
-        Duration::from_nanos(total_tokenization_time / batch_size),
-        Duration::from_nanos(total_queue_time / batch_size),
-        Duration::from_nanos(total_inference_time / batch_size),
-    );
+        .collect::<Result<Vec<VertexResponseInstance>, (StatusCode, Json<ErrorResponse>)>>()?;
 
-    metadata.record_span(&span);
-    metadata.record_metrics();
-    tracing::info!("Success");
-
-    Ok((HeaderMap::from(metadata), Json(response)))
+    Ok(Json(VertexResponse { predictions }))
 }
 
 /// Prometheus metrics scrape endpoint
@@ -1354,12 +1344,10 @@ pub async fn run(
     // avoid `mut` if possible
     #[cfg(feature = "google")]
     {
-        use crate::http::types::VertexInstance;
-
         #[derive(OpenApi)]
         #[openapi(
             paths(vertex_compatibility),
-            components(schemas(VertexInstance, VertexRequest))
+            components(schemas(VertexRequest, VertexResponse, VertexResponseInstance))
         )]
         struct VertextApiDoc;
 
@@ -1397,43 +1385,42 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics));
 
-    // Set default routes
-    app = match &info.model_type {
-        ModelType::Classifier(_) => {
-            app.route("/", post(predict))
-                // AWS Sagemaker route
-                .route("/invocations", post(predict))
-        }
-        ModelType::Reranker(_) => {
-            app.route("/", post(rerank))
-                // AWS Sagemaker route
-                .route("/invocations", post(rerank))
-        }
-        ModelType::Embedding(model) => {
-            if model.pooling == "splade" {
-                app.route("/", post(embed_sparse))
-                    // AWS Sagemaker route
-                    .route("/invocations", post(embed_sparse))
-            } else {
-                app.route("/", post(embed))
-                    // AWS Sagemaker route
-                    .route("/invocations", post(embed))
-            }
-        }
-    };
-
     #[cfg(feature = "google")]
     {
         tracing::info!("Built with `google` feature");
-        tracing::info!(
-            "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
-        );
-        if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
-            app = app.route(&env_predict_route, post(vertex_compatibility));
-        }
-        if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
-            app = app.route(&env_health_route, get(health));
-        }
+        let env_predict_route = std::env::var("AIP_PREDICT_ROUTE")
+            .context("`AIP_PREDICT_ROUTE` env var must be set for Google Vertex deployments")?;
+        app = app.route(&env_predict_route, post(vertex_compatibility));
+        let env_health_route = std::env::var("AIP_HEALTH_ROUTE")
+            .context("`AIP_HEALTH_ROUTE` env var must be set for Google Vertex deployments")?;
+        app = app.route(&env_health_route, get(health));
+    }
+    #[cfg(not(feature = "google"))]
+    {
+        // Set default routes
+        app = match &info.model_type {
+            ModelType::Classifier(_) => {
+                app.route("/", post(predict))
+                    // AWS Sagemaker route
+                    .route("/invocations", post(predict))
+            }
+            ModelType::Reranker(_) => {
+                app.route("/", post(rerank))
+                    // AWS Sagemaker route
+                    .route("/invocations", post(rerank))
+            }
+            ModelType::Embedding(model) => {
+                if model.pooling == "splade" {
+                    app.route("/", post(embed_sparse))
+                        // AWS Sagemaker route
+                        .route("/invocations", post(embed_sparse))
+                } else {
+                    app.route("/", post(embed))
+                        // AWS Sagemaker route
+                        .route("/invocations", post(embed))
+                }
+            }
+        };
     }
 
     app = app
@@ -1510,3 +1497,12 @@ impl From<ErrorResponse> for (StatusCode, Json<ErrorResponse>) {
         (StatusCode::from(&err.error_type), Json(err.into()))
     }
 }
+
+impl From<serde_json::Error> for ErrorResponse {
+    fn from(err: serde_json::Error) -> Self {
+        ErrorResponse {
+            error: err.to_string(),
+            error_type: ErrorType::Validation,
+        }
+    }
+}
diff --git a/router/src/http/types.rs b/router/src/http/types.rs
index a638fb29..30c9e728 100644
--- a/router/src/http/types.rs
+++ b/router/src/http/types.rs
@@ -382,19 +382,21 @@ pub(crate) struct SimpleToken {
 #[schema(example = json!([[{"id": 0, "text": "test", "special": false, "start": 0, "stop": 2}]]))]
 pub(crate) struct TokenizeResponse(pub Vec<Vec<SimpleToken>>);
 
-#[derive(Clone, Deserialize, ToSchema)]
-pub(crate) struct VertexInstance {
-    #[schema(example = "What is Deep Learning?")]
-    pub inputs: String,
-    #[serde(default)]
-    #[schema(default = "false", example = "false")]
-    pub truncate: bool,
-    #[serde(default = "default_normalize")]
-    #[schema(default = "true", example = "true")]
-    pub normalize: bool,
-}
-
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct VertexRequest {
-    pub instances: Vec<VertexInstance>,
+    pub instances: Vec<serde_json::Value>,
 }
+
+#[derive(Serialize, ToSchema)]
+#[serde(untagged)]
+pub(crate) enum VertexResponseInstance {
+    Embed(EmbedResponse),
+    EmbedSparse(EmbedSparseResponse),
+    Predict(PredictResponse),
+    Rerank(RerankResponse),
+}
+
+#[derive(Serialize, ToSchema)]
+pub(crate) struct VertexResponse {
+    pub predictions: Vec<VertexResponseInstance>,
+}
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 17008853..8a5715a6 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -246,7 +246,7 @@ pub async fn run(
         std::env::var("AIP_HTTP_PORT")
             .ok()
             .and_then(|p| p.parse().ok())
-            .expect("Invalid or unset AIP_HTTP_PORT")
+            .context("`AIP_HTTP_PORT` env var must be set for Google Vertex deployments")?
     } else {
         port
     };
@@ -264,6 +264,9 @@ pub async fn run(
 
     #[cfg(all(feature = "grpc", feature = "http"))]
     compile_error!("Features `http` and `grpc` cannot be enabled at the same time.");
 
+    #[cfg(all(feature = "grpc", feature = "google"))]
+    compile_error!("Features `grpc` and `google` cannot be enabled at the same time.");
+
     #[cfg(not(any(feature = "http", feature = "grpc")))]
     compile_error!("Either feature `http` or `grpc` must be enabled.");
 
diff --git a/router/tests/common.rs b/router/tests/common.rs
index 5e8842f4..29331188 100644
--- a/router/tests/common.rs
+++ b/router/tests/common.rs
@@ -60,6 +60,8 @@ pub async fn start_server(model_id: String, revision: Option<String>, dtype: DType
         8090,
         None,
         None,
+        2_000_000,
+        None,
         None,
         None,
     )
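
For reference, a request body for the new /vertex route might look like the sketch below. It assumes an embedding model is being served; the "type" discriminator and the EmbedRequest fields ("inputs", "truncate", "normalize") follow the VertexInstance schema added to docs/openapi.json above.

{
  "instances": [
    { "type": "embed", "inputs": "What is Deep Learning?" },
    { "type": "embed", "inputs": "What is Machine Learning?", "truncate": true, "normalize": true }
  ]
}

Each instance is first tried as a RerankRequest and otherwise deserialized according to the deployed model type, so rerank, predict, and sparse-embedding payloads use the same envelope with their own fields.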