14 changes: 7 additions & 7 deletions src/llm/anthropic.rs
@@ -1,6 +1,5 @@
 use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
-    ToJsonSchemaOptions,
+    LlmClient, LlmGenerateRequest, LlmGenerateResponse, OutputFormat, ToJsonSchemaOptions,
 };
 use anyhow::{Context, Result, bail};
 use async_trait::async_trait;
@@ -11,27 +10,28 @@ use crate::api_bail;
 use urlencoding::encode;

 pub struct Client {
-    model: String,
     api_key: String,
     client: reqwest::Client,
 }

 impl Client {
-    pub async fn new(spec: LlmSpec) -> Result<Self> {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        if address.is_some() {
+            api_bail!("Anthropic doesn't support custom API address");
+        }
         let api_key = match std::env::var("ANTHROPIC_API_KEY") {
             Ok(val) => val,
             Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"),
         };
         Ok(Self {
-            model: spec.model,
             api_key,
             client: reqwest::Client::new(),
         })
     }
 }

 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -42,7 +42,7 @@ impl LlmGenerationClient for Client {
         })];

         let mut payload = serde_json::json!({
-            "model": self.model,
+            "model": request.model,
             "messages": messages,
             "max_tokens": 4096
         });
14 changes: 7 additions & 7 deletions src/llm/gemini.rs
@@ -1,27 +1,27 @@
 use crate::api_bail;
 use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
-    ToJsonSchemaOptions,
+    LlmClient, LlmGenerateRequest, LlmGenerateResponse, OutputFormat, ToJsonSchemaOptions,
 };
 use anyhow::{Context, Result, bail};
 use async_trait::async_trait;
 use serde_json::Value;
 use urlencoding::encode;

 pub struct Client {
-    model: String,
     api_key: String,
     client: reqwest::Client,
 }

 impl Client {
-    pub async fn new(spec: LlmSpec) -> Result<Self> {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        if address.is_some() {
+            api_bail!("Gemini doesn't support custom API address");
+        }
         let api_key = match std::env::var("GEMINI_API_KEY") {
             Ok(val) => val,
             Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),
         };
         Ok(Self {
-            model: spec.model,
             api_key,
             client: reqwest::Client::new(),
         })
@@ -47,7 +47,7 @@ fn remove_additional_properties(value: &mut Value) {
 }

 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -79,7 +79,7 @@ impl LlmGenerationClient for Client {
         let api_key = &self.api_key;
         let url = format!(
             "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
-            encode(&self.model),
+            encode(request.model),
             encode(api_key)
         );

12 changes: 3 additions & 9 deletions src/llm/litellm.rs
@@ -4,19 +4,13 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;

 impl Client {
-    pub async fn new_litellm(spec: super::LlmSpec) -> anyhow::Result<Self> {
-        let address = spec
-            .address
-            .clone()
-            .unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
+    pub async fn new_litellm(address: Option<String>) -> anyhow::Result<Self> {
+        let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
         let api_key = std::env::var("LITELLM_API_KEY").ok();
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
         }
-        Ok(Client::from_parts(
-            OpenAIClient::with_config(config),
-            spec.model,
-        ))
+        Ok(Client::from_parts(OpenAIClient::with_config(config)))
     }
 }
37 changes: 18 additions & 19 deletions src/llm/mod.rs
@@ -19,9 +19,9 @@ pub enum LlmApiType {

 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct LlmSpec {
-    api_type: LlmApiType,
-    address: Option<String>,
-    model: String,
+    pub api_type: LlmApiType,
+    pub address: Option<String>,
+    pub model: String,
 }

 #[derive(Debug)]
@@ -34,6 +34,7 @@ pub enum OutputFormat<'a> {

 #[derive(Debug)]
 pub struct LlmGenerateRequest<'a> {
+    pub model: &'a str,
     pub system_prompt: Option<Cow<'a, str>>,
     pub user_prompt: Cow<'a, str>,
     pub output_format: Option<OutputFormat<'a>>,
@@ -45,7 +46,7 @@ pub struct LlmGenerateResponse {
 }

 #[async_trait]
-pub trait LlmGenerationClient: Send + Sync {
+pub trait LlmClient: Send + Sync {
     async fn generate<'req>(
         &self,
         request: LlmGenerateRequest<'req>,
@@ -61,25 +62,23 @@ mod ollama;
 mod openai;
 mod openrouter;

-pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
-    let client = match spec.api_type {
-        LlmApiType::Ollama => {
-            Box::new(ollama::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
-        }
-        LlmApiType::OpenAi => {
-            Box::new(openai::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
-        }
-        LlmApiType::Gemini => {
-            Box::new(gemini::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
-        }
+pub async fn new_llm_generation_client(
+    api_type: LlmApiType,
+    address: Option<String>,
+) -> Result<Box<dyn LlmClient>> {
+    let client = match api_type {
+        LlmApiType::Ollama => Box::new(ollama::Client::new(address).await?) as Box<dyn LlmClient>,
+        LlmApiType::OpenAi => Box::new(openai::Client::new(address).await?) as Box<dyn LlmClient>,
+        LlmApiType::Gemini => Box::new(gemini::Client::new(address).await?) as Box<dyn LlmClient>,
         LlmApiType::Anthropic => {
-            Box::new(anthropic::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
+            Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmClient>
         }
         LlmApiType::LiteLlm => {
-            Box::new(litellm::Client::new_litellm(spec).await?) as Box<dyn LlmGenerationClient>
+            Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmClient>
         }
-        LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(spec).await?)
-            as Box<dyn LlmGenerationClient>,
+        LlmApiType::OpenRouter => {
+            Box::new(openrouter::Client::new_openrouter(address).await?) as Box<dyn LlmClient>
+        }
     };
     Ok(client)
 }
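
A hypothetical caller-side sketch (not part of this diff) of what the refactored mod.rs API enables: the client is built from the API type and optional address only, and the model name now travels with each request, so one client can serve several models. It assumes LlmGenerateRequest has exactly the fields visible in this diff, and the model names are made up purely for illustration.

// Hypothetical usage sketch of the refactored API; assumes it lives inside the
// same crate so the crate::llm paths resolve.
use std::borrow::Cow;

use crate::llm::{LlmApiType, LlmGenerateRequest, new_llm_generation_client};

async fn generate_with_two_models() -> anyhow::Result<()> {
    // One Ollama client, no custom address (falls back to the default endpoint).
    let client = new_llm_generation_client(LlmApiType::Ollama, None).await?;

    for model in ["llama3.2", "mistral"] {
        let response = client
            .generate(LlmGenerateRequest {
                model, // per-request; no longer stored on the client
                system_prompt: None,
                user_prompt: Cow::Borrowed("Summarize this in one sentence."),
                output_format: None,
            })
            .await?;
        // LlmGenerateResponse's fields aren't shown in this diff, so just drop it here.
        let _ = response;
    }
    Ok(())
}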
12 changes: 5 additions & 7 deletions src/llm/ollama.rs
@@ -1,12 +1,11 @@
-use super::LlmGenerationClient;
+use super::LlmClient;
 use anyhow::Result;
 use async_trait::async_trait;
 use schemars::schema::SchemaObject;
 use serde::{Deserialize, Serialize};

 pub struct Client {
     generate_url: String,
-    model: String,
     reqwest_client: reqwest::Client,
 }

@@ -33,27 +32,26 @@ struct OllamaResponse {
 const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";

 impl Client {
-    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
-        let address = match &spec.address {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        let address = match &address {
             Some(addr) => addr.trim_end_matches('/'),
             None => OLLAMA_DEFAULT_ADDRESS,
         };
         Ok(Self {
             generate_url: format!("{}/api/generate", address),
-            model: spec.model,
             reqwest_client: reqwest::Client::new(),
         })
     }
 }

 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,
     ) -> Result<super::LlmGenerateResponse> {
         let req = OllamaRequest {
-            model: &self.model,
+            model: request.model,
             prompt: request.user_prompt.as_ref(),
             format: request.output_format.as_ref().map(
                 |super::OutputFormat::JsonSchema { schema, .. }| {
16 changes: 7 additions & 9 deletions src/llm/openai.rs
@@ -1,6 +1,6 @@
 use crate::api_bail;

-use super::LlmGenerationClient;
+use super::LlmClient;
 use anyhow::Result;
 use async_openai::{
     Client as OpenAIClient,
@@ -16,16 +16,15 @@ use async_trait::async_trait;

 pub struct Client {
     client: async_openai::Client<OpenAIConfig>,
-    model: String,
 }

 impl Client {
-    pub(crate) fn from_parts(client: async_openai::Client<OpenAIConfig>, model: String) -> Self {
-        Self { client, model }
+    pub(crate) fn from_parts(client: async_openai::Client<OpenAIConfig>) -> Self {
+        Self { client }
     }

-    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
-        if let Some(address) = spec.address {
+    pub async fn new(address: Option<String>) -> Result<Self> {
+        if let Some(address) = address {
             api_bail!("OpenAI doesn't support custom API address: {address}");
         }
         // Verify API key is set
@@ -35,13 +34,12 @@ impl Client {
         Ok(Self {
             // OpenAI client will use OPENAI_API_KEY env variable by default
             client: OpenAIClient::new(),
-            model: spec.model,
         })
     }
 }

 #[async_trait]
-impl LlmGenerationClient for Client {
+impl LlmClient for Client {
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,
@@ -70,7 +68,7 @@ impl LlmGenerationClient for Client {

         // Create the chat completion request
         let request = CreateChatCompletionRequest {
-            model: self.model.clone(),
+            model: request.model.to_string(),
             messages,
             response_format: match request.output_format {
                 Some(super::OutputFormat::JsonSchema { name, schema }) => {
12 changes: 3 additions & 9 deletions src/llm/openrouter.rs
@@ -4,19 +4,13 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;

 impl Client {
-    pub async fn new_openrouter(spec: super::LlmSpec) -> anyhow::Result<Self> {
-        let address = spec
-            .address
-            .clone()
-            .unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
+    pub async fn new_openrouter(address: Option<String>) -> anyhow::Result<Self> {
+        let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
         let api_key = std::env::var("OPENROUTER_API_KEY").ok();
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
         }
-        Ok(Client::from_parts(
-            OpenAIClient::with_config(config),
-            spec.model,
-        ))
+        Ok(Client::from_parts(OpenAIClient::with_config(config)))
     }
 }
12 changes: 7 additions & 5 deletions src/ops/functions/extract_by_llm.rs
@@ -1,8 +1,6 @@
 use crate::prelude::*;

-use crate::llm::{
-    LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat, new_llm_generation_client,
-};
+use crate::llm::{LlmClient, LlmGenerateRequest, LlmSpec, OutputFormat, new_llm_generation_client};
 use crate::ops::sdk::*;
 use base::json_schema::build_json_schema;
 use schemars::schema::SchemaObject;
@@ -21,7 +19,8 @@ pub struct Args {

 struct Executor {
     args: Args,
-    client: Box<dyn LlmGenerationClient>,
+    client: Box<dyn LlmClient>,
+    model: String,
     output_json_schema: SchemaObject,
     system_prompt: String,
     value_extractor: base::json_schema::ValueExtractor,
@@ -50,11 +49,13 @@ Output only the JSON without any additional messages or explanations."

 impl Executor {
     async fn new(spec: Spec, args: Args) -> Result<Self> {
-        let client = new_llm_generation_client(spec.llm_spec).await?;
+        let client =
+            new_llm_generation_client(spec.llm_spec.api_type, spec.llm_spec.address).await?;
         let schema_output = build_json_schema(spec.output_type, client.json_schema_options())?;
         Ok(Self {
             args,
             client,
+            model: spec.llm_spec.model,
             output_json_schema: schema_output.schema,
             system_prompt: get_system_prompt(&spec.instruction, schema_output.extra_instructions),
             value_extractor: schema_output.value_extractor,
@@ -75,6 +76,7 @@ impl SimpleFunctionExecutor for Executor {
     async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
         let text = self.args.text.value(&input)?.as_str()?;
         let req = LlmGenerateRequest {
+            model: &self.model,
             system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
             user_prompt: Cow::Borrowed(text),
             output_format: Some(OutputFormat::JsonSchema {