From 55c2451d77f56e13973310584f9e5bd1844a6880 Mon Sep 17 00:00:00 2001
From: LJ
Date: Sun, 13 Jul 2025 12:15:24 -0700
Subject: [PATCH 1/2] feat(vertex): support Vertex AI for embedding

---
 src/llm/gemini.rs               | 105 ++++++++++++++++++++++++++------
 src/llm/mod.rs                  |  20 +++---
 src/ops/functions/embed_text.rs |   6 +-
 3 files changed, 97 insertions(+), 34 deletions(-)

diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs
index 1eb86974..9586a249 100644
--- a/src/llm/gemini.rs
+++ b/src/llm/gemini.rs
@@ -6,15 +6,23 @@ use crate::llm::{
 };
 use base64::prelude::*;
 use google_cloud_aiplatform_v1 as vertexai;
-use phf::phf_map;
 use serde_json::Value;
 use urlencoding::encode;
 
-static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
-    "gemini-embedding-exp-03-07" => 3072,
-    "text-embedding-004" => 768,
-    "embedding-001" => 768,
-};
+fn get_embedding_dimension(model: &str) -> Option<u32> {
+    let model = model.to_ascii_lowercase();
+    if model.starts_with("gemini-embedding-") {
+        Some(3072)
+    } else if model.starts_with("text-embedding-") {
+        Some(768)
+    } else if model.starts_with("embedding-") {
+        Some(768)
+    } else if model.starts_with("text-multilingual-embedding-") {
+        Some(768)
+    } else {
+        None
+    }
+}
 
 pub struct AiStudioClient {
     api_key: String,
@@ -192,7 +200,7 @@ impl LlmEmbeddingClient for AiStudioClient {
     }
 
     fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
-        DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied()
+        get_embedding_dimension(model)
     }
 }
 
@@ -202,12 +210,30 @@ pub struct VertexAiClient {
     client: vertexai::client::PredictionService,
     config: super::VertexAiConfig,
 }
 
 impl VertexAiClient {
-    pub async fn new(config: super::VertexAiConfig) -> Result<Self> {
+    pub async fn new(
+        address: Option<String>,
+        api_config: Option<super::LlmApiConfig>,
+    ) -> Result<Self> {
+        if address.is_some() {
+            api_bail!("VertexAi API address is not supported for VertexAi API type");
+        }
+        let Some(super::LlmApiConfig::VertexAi(config)) = api_config else {
+            api_bail!("VertexAi API config is required for VertexAi API type");
+        };
         let client = vertexai::client::PredictionService::builder()
             .build()
             .await?;
         Ok(Self { client, config })
     }
+
+    fn get_model_path(&self, model: &str) -> String {
+        format!(
+            "projects/{}/locations/{}/publishers/google/models/{}",
+            self.config.project,
+            self.config.region.as_deref().unwrap_or("global"),
+            model
+        )
+    }
 }
 
 #[async_trait]
@@ -254,20 +280,10 @@ impl LlmGenerationClient for VertexAiClient {
             );
         }
 
-        // projects/{project_id}/locations/global/publishers/google/models/{MODEL}
-
-        let model = format!(
-            "projects/{}/locations/{}/publishers/google/models/{}",
-            self.config.project,
-            self.config.region.as_deref().unwrap_or("global"),
-            request.model
-        );
-
-        // Build the request
         let mut req = self
             .client
             .generate_content()
-            .set_model(model)
+            .set_model(self.get_model_path(request.model))
             .set_contents(contents);
         if let Some(sys) = system_instruction {
             req = req.set_system_instruction(sys);
@@ -301,3 +317,54 @@ impl LlmGenerationClient for VertexAiClient {
         }
     }
 }
+
+#[async_trait]
+impl LlmEmbeddingClient for VertexAiClient {
+    async fn embed_text<'req>(
+        &self,
+        request: super::LlmEmbeddingRequest<'req>,
+    ) -> Result<super::LlmEmbeddingResponse> {
+        // Create the instances for the request
+        let mut instance = serde_json::json!({
+            "content": request.text
+        });
+        // Add task type if specified
+        if let Some(task_type) = &request.task_type {
+            instance["task_type"] = serde_json::Value::String(task_type.to_string());
+        }
+
+        let instances = vec![instance];
+
+        // Prepare the request parameters
+        let mut parameters = serde_json::json!({});
+        if let Some(output_dimension) = request.output_dimension {
parameters["outputDimensionality"] = serde_json::Value::Number(output_dimension.into()); + } + + // Build the prediction request using the raw predict builder + let response = self + .client + .predict() + .set_endpoint(self.get_model_path(request.model)) + .set_instances(instances) + .set_parameters(parameters) + .send() + .await?; + + // Extract the embedding from the response + let embeddings = response + .predictions + .into_iter() + .next() + .and_then(|mut e| e.get_mut("embeddings").map(|v| v.take())) + .ok_or_else(|| anyhow::anyhow!("No embeddings in response"))?; + let embedding: ContentEmbedding = serde_json::from_value(embeddings)?; + Ok(super::LlmEmbeddingResponse { + embedding: embedding.values, + }) + } + + fn get_default_embedding_dimension(&self, model: &str) -> Option { + get_embedding_dimension(model) + } +} diff --git a/src/llm/mod.rs b/src/llm/mod.rs index a89f9e67..914cb71a 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -119,16 +119,8 @@ pub async fn new_llm_generation_client( LlmApiType::Gemini => { Box::new(gemini::AiStudioClient::new(address)?) as Box } - LlmApiType::VertexAi => { - if address.is_some() { - api_bail!("VertexAi API address is not supported for VertexAi API type"); - } - let Some(LlmApiConfig::VertexAi(config)) = api_config else { - api_bail!("VertexAi API config is required for VertexAi API type"); - }; - let config = config.clone(); - Box::new(gemini::VertexAiClient::new(config).await?) as Box - } + LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?) + as Box, LlmApiType::Anthropic => { Box::new(anthropic::Client::new(address).await?) as Box } @@ -147,9 +139,10 @@ pub async fn new_llm_generation_client( Ok(client) } -pub fn new_llm_embedding_client( +pub async fn new_llm_embedding_client( api_type: LlmApiType, address: Option, + api_config: Option, ) -> Result> { let client = match api_type { LlmApiType::Gemini => { @@ -161,12 +154,13 @@ pub fn new_llm_embedding_client( LlmApiType::Voyage => { Box::new(voyage::Client::new(address)?) as Box } + LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?) 
+            as Box<dyn LlmEmbeddingClient>,
         LlmApiType::Ollama
         | LlmApiType::OpenRouter
         | LlmApiType::LiteLlm
         | LlmApiType::Vllm
-        | LlmApiType::Anthropic
-        | LlmApiType::VertexAi => {
+        | LlmApiType::Anthropic => {
             api_bail!("Embedding is not supported for API type {:?}", api_type)
         }
     };
diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs
index 8688bb36..f90b7884 100644
--- a/src/ops/functions/embed_text.rs
+++ b/src/ops/functions/embed_text.rs
@@ -1,5 +1,5 @@
 use crate::{
-    llm::{LlmApiType, LlmEmbeddingClient, LlmEmbeddingRequest, new_llm_embedding_client},
+    llm::{LlmApiConfig, LlmApiType, LlmEmbeddingClient, LlmEmbeddingRequest, new_llm_embedding_client},
     ops::sdk::*,
 };
 
@@ -8,6 +8,7 @@ struct Spec {
     api_type: LlmApiType,
     model: String,
     address: Option<String>,
+    api_config: Option<LlmApiConfig>,
     output_dimension: Option<u32>,
     task_type: Option<String>,
 }
@@ -67,7 +68,7 @@ impl SimpleFunctionFactoryBase for Factory {
         _context: &FlowInstanceContext,
     ) -> Result<(Self::ResolvedArgs, EnrichedValueType)> {
         let text = args_resolver.next_arg("text")?;
-        let client = new_llm_embedding_client(spec.api_type, spec.address.clone())?;
+        let client = new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone()).await?;
         let output_dimension = match spec.output_dimension {
             Some(output_dimension) => output_dimension,
             None => {
@@ -108,6 +109,7 @@ mod tests {
             api_type: LlmApiType::OpenAi,
             model: "text-embedding-ada-002".to_string(),
             address: None,
+            api_config: None,
             output_dimension: None,
             task_type: None,
         };

From f2b6f72f0f87f59a0876df11930612645c2e98df Mon Sep 17 00:00:00 2001
From: LJ
Date: Sun, 13 Jul 2025 12:18:57 -0700
Subject: [PATCH 2/2] style: format fix

---
 src/ops/functions/embed_text.rs | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs
index f90b7884..2857484f 100644
--- a/src/ops/functions/embed_text.rs
+++ b/src/ops/functions/embed_text.rs
@@ -1,5 +1,7 @@
 use crate::{
-    llm::{LlmApiConfig, LlmApiType, LlmEmbeddingClient, LlmEmbeddingRequest, new_llm_embedding_client},
+    llm::{
+        LlmApiConfig, LlmApiType, LlmEmbeddingClient, LlmEmbeddingRequest, new_llm_embedding_client,
+    },
     ops::sdk::*,
 };
 
@@ -68,7 +70,9 @@ impl SimpleFunctionFactoryBase for Factory {
         _context: &FlowInstanceContext,
     ) -> Result<(Self::ResolvedArgs, EnrichedValueType)> {
         let text = args_resolver.next_arg("text")?;
-        let client = new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone()).await?;
+        let client =
+            new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone())
+                .await?;
         let output_dimension = match spec.output_dimension {
             Some(output_dimension) => output_dimension,
             None => {