diff --git a/src/llm/anthropic.rs b/src/llm/anthropic.rs
index 57d968aa..08315523 100644
--- a/src/llm/anthropic.rs
+++ b/src/llm/anthropic.rs
@@ -1,6 +1,5 @@
 use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
-    ToJsonSchemaOptions,
+    LlmClient, LlmGenerateRequest, LlmGenerateResponse, OutputFormat, ToJsonSchemaOptions,
 };
 use anyhow::{Context, Result, bail};
 use async_trait::async_trait;
@@ -11,19 +10,20 @@ use crate::api_bail;
 use urlencoding::encode;
 
 pub struct Client {
-    model: String,
     api_key: String,
     client: reqwest::Client,
 }
 
 impl Client {
-    pub async fn new(spec: LlmSpec) -> Result<Self> {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        if address.is_some() {
+            api_bail!("Anthropic doesn't support custom API address");
+        }
         let api_key = match std::env::var("ANTHROPIC_API_KEY") {
             Ok(val) => val,
             Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"),
         };
         Ok(Self {
-            model: spec.model,
             api_key,
             client: reqwest::Client::new(),
         })
@@ -31,7 +31,7 @@ impl Client {
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -42,7 +42,7 @@ impl LlmGenerationClient for Client {
         })];
 
         let mut payload = serde_json::json!({
-            "model": self.model,
+            "model": request.model,
             "messages": messages,
             "max_tokens": 4096
         });
diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs
index f48f839d..09c6b8ca 100644
--- a/src/llm/gemini.rs
+++ b/src/llm/gemini.rs
@@ -1,7 +1,6 @@
 use crate::api_bail;
 use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
-    ToJsonSchemaOptions,
+    LlmClient, LlmGenerateRequest, LlmGenerateResponse, OutputFormat, ToJsonSchemaOptions,
 };
 use anyhow::{Context, Result, bail};
 use async_trait::async_trait;
@@ -9,19 +8,20 @@ use serde_json::Value;
 use urlencoding::encode;
 
 pub struct Client {
-    model: String,
     api_key: String,
     client: reqwest::Client,
 }
 
 impl Client {
-    pub async fn new(spec: LlmSpec) -> Result<Self> {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        if address.is_some() {
+            api_bail!("Gemini doesn't support custom API address");
+        }
         let api_key = match std::env::var("GEMINI_API_KEY") {
             Ok(val) => val,
             Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),
         };
         Ok(Self {
-            model: spec.model,
             api_key,
             client: reqwest::Client::new(),
         })
@@ -47,7 +47,7 @@ fn remove_additional_properties(value: &mut Value) {
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -79,7 +79,7 @@ impl LlmGenerationClient for Client {
         let api_key = &self.api_key;
         let url = format!(
             "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
-            encode(&self.model),
+            encode(request.model),
             encode(api_key)
         );
 
diff --git a/src/llm/litellm.rs b/src/llm/litellm.rs
index 27648747..85d1b50e 100644
--- a/src/llm/litellm.rs
+++ b/src/llm/litellm.rs
@@ -4,19 +4,13 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;
 
 impl Client {
-    pub async fn new_litellm(spec: super::LlmSpec) -> anyhow::Result<Self> {
-        let address = spec
-            .address
-            .clone()
-            .unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
+    pub async fn new_litellm(address: Option<String>) -> anyhow::Result<Self> {
+        let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
         let api_key = std::env::var("LITELLM_API_KEY").ok();
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
         }
-        Ok(Client::from_parts(
-            OpenAIClient::with_config(config),
-            spec.model,
-        ))
+        Ok(Client::from_parts(OpenAIClient::with_config(config)))
     }
 }
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index ea4aa58e..32bdd5a1 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -19,9 +19,9 @@ pub enum LlmApiType {
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct LlmSpec {
-    api_type: LlmApiType,
-    address: Option<String>,
-    model: String,
+    pub api_type: LlmApiType,
+    pub address: Option<String>,
+    pub model: String,
 }
 
 #[derive(Debug)]
@@ -34,6 +34,7 @@ pub enum OutputFormat<'a> {
 
 #[derive(Debug)]
 pub struct LlmGenerateRequest<'a> {
+    pub model: &'a str,
     pub system_prompt: Option<Cow<'a, str>>,
     pub user_prompt: Cow<'a, str>,
     pub output_format: Option<OutputFormat<'a>>,
@@ -45,7 +46,7 @@ pub struct LlmGenerateResponse {
 }
 
 #[async_trait]
-pub trait LlmGenerationClient: Send + Sync {
+pub trait LlmClient: Send + Sync {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -61,25 +62,23 @@ mod ollama;
 mod openai;
 mod openrouter;
 
-pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
-    let client = match spec.api_type {
-        LlmApiType::Ollama => {
-            Box::new(ollama::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
-        }
-        LlmApiType::OpenAi => {
-            Box::new(openai::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
-        }
-        LlmApiType::Gemini => {
-            Box::new(gemini::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
-        }
+pub async fn new_llm_generation_client(
+    api_type: LlmApiType,
+    address: Option<String>,
+) -> Result<Box<dyn LlmClient>> {
+    let client = match api_type {
+        LlmApiType::Ollama => Box::new(ollama::Client::new(address).await?) as Box<dyn LlmClient>,
+        LlmApiType::OpenAi => Box::new(openai::Client::new(address).await?) as Box<dyn LlmClient>,
+        LlmApiType::Gemini => Box::new(gemini::Client::new(address).await?) as Box<dyn LlmClient>,
         LlmApiType::Anthropic => {
-            Box::new(anthropic::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
+            Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmClient>
         }
         LlmApiType::LiteLlm => {
-            Box::new(litellm::Client::new_litellm(spec).await?) as Box<dyn LlmGenerationClient>
+            Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmClient>
+        }
+        LlmApiType::OpenRouter => {
+            Box::new(openrouter::Client::new_openrouter(address).await?) as Box<dyn LlmClient>
         }
-        LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(spec).await?)
-            as Box<dyn LlmGenerationClient>,
     };
     Ok(client)
 }
diff --git a/src/llm/ollama.rs b/src/llm/ollama.rs
index f2926077..7b79293f 100644
--- a/src/llm/ollama.rs
+++ b/src/llm/ollama.rs
@@ -1,4 +1,4 @@
-use super::LlmGenerationClient;
+use super::LlmClient;
 use anyhow::Result;
 use async_trait::async_trait;
 use schemars::schema::SchemaObject;
@@ -6,7 +6,6 @@ use serde::{Deserialize, Serialize};
 
 pub struct Client {
     generate_url: String,
-    model: String,
     reqwest_client: reqwest::Client,
 }
 
@@ -33,27 +32,26 @@ struct OllamaResponse {
 const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";
 
 impl Client {
-    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
-        let address = match &spec.address {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        let address = match &address {
             Some(addr) => addr.trim_end_matches('/'),
             None => OLLAMA_DEFAULT_ADDRESS,
         };
         Ok(Self {
             generate_url: format!("{}/api/generate", address),
-            model: spec.model,
             reqwest_client: reqwest::Client::new(),
         })
     }
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,
     ) -> Result<super::LlmGenerateResponse> {
         let req = OllamaRequest {
-            model: &self.model,
+            model: request.model,
             prompt: request.user_prompt.as_ref(),
             format: request.output_format.as_ref().map(
                 |super::OutputFormat::JsonSchema { schema, .. }| {
diff --git a/src/llm/openai.rs b/src/llm/openai.rs
index edbfa59f..3011ac5c 100644
--- a/src/llm/openai.rs
+++ b/src/llm/openai.rs
@@ -1,6 +1,6 @@
 use crate::api_bail;
 
-use super::LlmGenerationClient;
+use super::LlmClient;
 use anyhow::Result;
 use async_openai::{
     Client as OpenAIClient,
@@ -16,16 +16,15 @@ use async_trait::async_trait;
 
 pub struct Client {
     client: async_openai::Client<OpenAIConfig>,
-    model: String,
 }
 
 impl Client {
-    pub(crate) fn from_parts(client: async_openai::Client<OpenAIConfig>, model: String) -> Self {
-        Self { client, model }
+    pub(crate) fn from_parts(client: async_openai::Client<OpenAIConfig>) -> Self {
+        Self { client }
     }
 
-    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
-        if let Some(address) = spec.address {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        if let Some(address) = address {
             api_bail!("OpenAI doesn't support custom API address: {address}");
         }
         // Verify API key is set
@@ -35,13 +34,12 @@ impl Client {
         Ok(Self {
             // OpenAI client will use OPENAI_API_KEY env variable by default
             client: OpenAIClient::new(),
-            model: spec.model,
         })
     }
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,
@@ -70,7 +68,7 @@ impl LlmGenerationClient for Client {
 
         // Create the chat completion request
         let request = CreateChatCompletionRequest {
-            model: self.model.clone(),
+            model: request.model.to_string(),
             messages,
             response_format: match request.output_format {
                 Some(super::OutputFormat::JsonSchema { name, schema }) => {
diff --git a/src/llm/openrouter.rs b/src/llm/openrouter.rs
index cb775788..ecf4d0fa 100644
--- a/src/llm/openrouter.rs
+++ b/src/llm/openrouter.rs
@@ -4,19 +4,13 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;
 
 impl Client {
-    pub async fn new_openrouter(spec: super::LlmSpec) -> anyhow::Result<Self> {
-        let address = spec
-            .address
-            .clone()
-            .unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
+    pub async fn new_openrouter(address: Option<String>) -> anyhow::Result<Self> {
+        let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
         let api_key = std::env::var("OPENROUTER_API_KEY").ok();
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
         }
-        Ok(Client::from_parts(
-            OpenAIClient::with_config(config),
-            spec.model,
-        ))
+        Ok(Client::from_parts(OpenAIClient::with_config(config)))
     }
 }
diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs
index 904bb234..ec9782c8 100644
--- a/src/ops/functions/extract_by_llm.rs
+++ b/src/ops/functions/extract_by_llm.rs
@@ -1,8 +1,6 @@
 use crate::prelude::*;
 
-use crate::llm::{
-    LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat, new_llm_generation_client,
-};
+use crate::llm::{LlmClient, LlmGenerateRequest, LlmSpec, OutputFormat, new_llm_generation_client};
 use crate::ops::sdk::*;
 use base::json_schema::build_json_schema;
 use schemars::schema::SchemaObject;
@@ -21,7 +19,8 @@ pub struct Args {
 
 struct Executor {
     args: Args,
-    client: Box<dyn LlmGenerationClient>,
+    client: Box<dyn LlmClient>,
+    model: String,
     output_json_schema: SchemaObject,
     system_prompt: String,
     value_extractor: base::json_schema::ValueExtractor,
@@ -50,11 +49,13 @@ Output only the JSON without any additional messages or explanations."
 
 impl Executor {
     async fn new(spec: Spec, args: Args) -> Result<Self> {
-        let client = new_llm_generation_client(spec.llm_spec).await?;
+        let client =
+            new_llm_generation_client(spec.llm_spec.api_type, spec.llm_spec.address).await?;
         let schema_output = build_json_schema(spec.output_type, client.json_schema_options())?;
         Ok(Self {
             args,
             client,
+            model: spec.llm_spec.model,
             output_json_schema: schema_output.schema,
             system_prompt: get_system_prompt(&spec.instruction, schema_output.extra_instructions),
             value_extractor: schema_output.value_extractor,
@@ -75,6 +76,7 @@ impl SimpleFunctionExecutor for Executor {
     async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
         let text = self.args.text.value(&input)?.as_str()?;
         let req = LlmGenerateRequest {
+            model: &self.model,
            system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
            user_prompt: Cow::Borrowed(text),
            output_format: Some(OutputFormat::JsonSchema {
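
For context, a minimal usage sketch of the refactored API (not part of the diff): the client is now built from the spec's `api_type` and `address` only, and the model name travels with each `LlmGenerateRequest`, mirroring what `extract_by_llm::Executor` does above. The `run_extraction` helper and the assumption that `LlmGenerateResponse` exposes its generated text as a `text` field are illustrative, not taken from this diff.

```rust
use std::borrow::Cow;

use anyhow::Result;

use crate::llm::{LlmGenerateRequest, LlmSpec, new_llm_generation_client};

// Hypothetical caller: build the client once from (api_type, address),
// then choose the model per request instead of baking it into the client.
async fn run_extraction(spec: LlmSpec, prompt: &str) -> Result<String> {
    let client = new_llm_generation_client(spec.api_type, spec.address).await?;
    let response = client
        .generate(LlmGenerateRequest {
            model: &spec.model,
            system_prompt: None,
            user_prompt: Cow::Borrowed(prompt),
            output_format: None,
        })
        .await?;
    // Assumption: the response carries the generated text in a `text` field.
    Ok(response.text)
}
```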