105 changes: 86 additions & 19 deletions src/llm/gemini.rs
@@ -6,15 +6,23 @@ use crate::llm::{
};
use base64::prelude::*;
use google_cloud_aiplatform_v1 as vertexai;
use phf::phf_map;
use serde_json::Value;
use urlencoding::encode;

static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
"gemini-embedding-exp-03-07" => 3072,
"text-embedding-004" => 768,
"embedding-001" => 768,
};
fn get_embedding_dimension(model: &str) -> Option<u32> {
let model = model.to_ascii_lowercase();
if model.starts_with("gemini-embedding-") {
Some(3072)
} else if model.starts_with("text-embedding-") {
Some(768)
} else if model.starts_with("embedding-") {
Some(768)
} else if model.starts_with("text-multilingual-embedding-") {
Some(768)
} else {
None
}
}

pub struct AiStudioClient {
api_key: String,
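
For reference, the prefix matching in get_embedding_dimension above resolves model names case-insensitively; a quick sanity check looks like this (a sketch; the model names are illustrative):

// Illustrative expectations for the prefix-based lookup above.
assert_eq!(get_embedding_dimension("gemini-embedding-exp-03-07"), Some(3072));
assert_eq!(get_embedding_dimension("Text-Embedding-004"), Some(768)); // case-insensitive
assert_eq!(get_embedding_dimension("text-multilingual-embedding-002"), Some(768));
assert_eq!(get_embedding_dimension("some-unknown-model"), None);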
@@ -192,7 +200,7 @@ impl LlmEmbeddingClient for AiStudioClient {
}

fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied()
get_embedding_dimension(model)
}
}

@@ -202,12 +210,30 @@ pub struct VertexAiClient {
}

impl VertexAiClient {
pub async fn new(config: super::VertexAiConfig) -> Result<Self> {
pub async fn new(
address: Option<String>,
api_config: Option<super::LlmApiConfig>,
) -> Result<Self> {
if address.is_some() {
api_bail!("VertexAi API address is not supported for VertexAi API type");
}
let Some(super::LlmApiConfig::VertexAi(config)) = api_config else {
api_bail!("VertexAi API config is required for VertexAi API type");
};
let client = vertexai::client::PredictionService::builder()
.build()
.await?;
Ok(Self { client, config })
}

fn get_model_path(&self, model: &str) -> String {
format!(
"projects/{}/locations/{}/publishers/google/models/{}",
self.config.project,
self.config.region.as_deref().unwrap_or("global"),
model
)
}
}

#[async_trait]
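
With the get_model_path helper above, a client configured with project "my-project" and no region targets the global publisher path (values hypothetical):

// get_model_path("gemini-2.0-flash") with project = "my-project", region = None yields:
// "projects/my-project/locations/global/publishers/google/models/gemini-2.0-flash"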
@@ -254,20 +280,10 @@ impl LlmGenerationClient for VertexAiClient {
);
}

// projects/{project_id}/locations/global/publishers/google/models/{MODEL}

let model = format!(
"projects/{}/locations/{}/publishers/google/models/{}",
self.config.project,
self.config.region.as_deref().unwrap_or("global"),
request.model
);

// Build the request
let mut req = self
.client
.generate_content()
.set_model(model)
.set_model(self.get_model_path(request.model))
.set_contents(contents);
if let Some(sys) = system_instruction {
req = req.set_system_instruction(sys);
@@ -301,3 +317,54 @@ impl LlmGenerationClient for VertexAiClient {
}
}
}

#[async_trait]
impl LlmEmbeddingClient for VertexAiClient {
async fn embed_text<'req>(
&self,
request: super::LlmEmbeddingRequest<'req>,
) -> Result<super::LlmEmbeddingResponse> {
// Create the instances for the request
let mut instance = serde_json::json!({
"content": request.text
});
// Add task type if specified
if let Some(task_type) = &request.task_type {
instance["task_type"] = serde_json::Value::String(task_type.to_string());
}

let instances = vec![instance];

// Prepare the request parameters
let mut parameters = serde_json::json!({});
if let Some(output_dimension) = request.output_dimension {
parameters["outputDimensionality"] = serde_json::Value::Number(output_dimension.into());
}

// Build the prediction request using the raw predict builder
let response = self
.client
.predict()
.set_endpoint(self.get_model_path(request.model))
.set_instances(instances)
.set_parameters(parameters)
.send()
.await?;

// Extract the embedding from the response
let embeddings = response
.predictions
.into_iter()
.next()
.and_then(|mut e| e.get_mut("embeddings").map(|v| v.take()))
.ok_or_else(|| anyhow::anyhow!("No embeddings in response"))?;
let embedding: ContentEmbedding = serde_json::from_value(embeddings)?;
Ok(super::LlmEmbeddingResponse {
embedding: embedding.values,
})
}

fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
get_embedding_dimension(model)
}
}
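
Put together, embed_text above sends a single-instance predict request; the values it assembles look roughly like this (a sketch built with serde_json; the text, task type, and dimension are placeholders):

use serde_json::json;

// What `instances` and `parameters` hold for one input, per the code above.
let instances = vec![json!({
    "content": "some document text",    // request.text
    "task_type": "RETRIEVAL_DOCUMENT",  // only present when request.task_type is Some
})];
let parameters = json!({
    "outputDimensionality": 768,        // only present when request.output_dimension is Some
});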
20 changes: 7 additions & 13 deletions src/llm/mod.rs
@@ -119,16 +119,8 @@ pub async fn new_llm_generation_client(
LlmApiType::Gemini => {
Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmGenerationClient>
}
LlmApiType::VertexAi => {
if address.is_some() {
api_bail!("VertexAi API address is not supported for VertexAi API type");
}
let Some(LlmApiConfig::VertexAi(config)) = api_config else {
api_bail!("VertexAi API config is required for VertexAi API type");
};
let config = config.clone();
Box::new(gemini::VertexAiClient::new(config).await?) as Box<dyn LlmGenerationClient>
}
LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
as Box<dyn LlmGenerationClient>,
LlmApiType::Anthropic => {
Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmGenerationClient>
}
@@ -147,9 +139,10 @@ pub fn new_llm_embedding_client(
Ok(client)
}

pub fn new_llm_embedding_client(
pub async fn new_llm_embedding_client(
api_type: LlmApiType,
address: Option<String>,
api_config: Option<LlmApiConfig>,
) -> Result<Box<dyn LlmEmbeddingClient>> {
let client = match api_type {
LlmApiType::Gemini => {
@@ -161,12 +154,13 @@ pub fn new_llm_embedding_client(
LlmApiType::Voyage => {
Box::new(voyage::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
}
LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
as Box<dyn LlmEmbeddingClient>,
LlmApiType::Ollama
| LlmApiType::OpenRouter
| LlmApiType::LiteLlm
| LlmApiType::Vllm
| LlmApiType::Anthropic
| LlmApiType::VertexAi => {
| LlmApiType::Anthropic => {
api_bail!("Embedding is not supported for API type {:?}", api_type)
}
};
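
A minimal call-site sketch for the updated signature, assuming a VertexAiConfig with the project and region fields used in gemini.rs (values hypothetical):

let client = new_llm_embedding_client(
    LlmApiType::VertexAi,
    None, // an explicit address is rejected for Vertex AI
    Some(LlmApiConfig::VertexAi(VertexAiConfig {
        project: "my-project".to_string(),
        region: None, // falls back to "global"
    })),
)
.await?;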
10 changes: 8 additions & 2 deletions src/ops/functions/embed_text.rs
@@ -1,5 +1,7 @@
use crate::{
llm::{LlmApiType, LlmEmbeddingClient, LlmEmbeddingRequest, new_llm_embedding_client},
llm::{
LlmApiConfig, LlmApiType, LlmEmbeddingClient, LlmEmbeddingRequest, new_llm_embedding_client,
},
ops::sdk::*,
};

@@ -8,6 +10,7 @@ struct Spec {
api_type: LlmApiType,
model: String,
address: Option<String>,
api_config: Option<LlmApiConfig>,
output_dimension: Option<u32>,
task_type: Option<String>,
}
@@ -67,7 +70,9 @@ impl SimpleFunctionFactoryBase for Factory {
_context: &FlowInstanceContext,
) -> Result<(Self::ResolvedArgs, EnrichedValueType)> {
let text = args_resolver.next_arg("text")?;
let client = new_llm_embedding_client(spec.api_type, spec.address.clone())?;
let client =
new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone())
.await?;
let output_dimension = match spec.output_dimension {
Some(output_dimension) => output_dimension,
None => {
@@ -108,6 +113,7 @@ mod tests {
api_type: LlmApiType::OpenAi,
model: "text-embedding-ada-002".to_string(),
address: None,
api_config: None,
output_dimension: None,
task_type: None,
};
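
For comparison, a Vertex AI spec would populate api_config instead of address (a sketch; the project and region values are hypothetical):

let spec = Spec {
    api_type: LlmApiType::VertexAi,
    model: "text-embedding-004".to_string(),
    address: None,
    api_config: Some(LlmApiConfig::VertexAi(VertexAiConfig {
        project: "my-project".to_string(),
        region: Some("us-central1".to_string()),
    })),
    output_dimension: None,
    task_type: None,
};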